GODT-1815: Combined/Split mode

This commit is contained in:
James Houlahan
2022-09-28 11:29:33 +02:00
parent 9670e29d9f
commit e9672e6bba
55 changed files with 1909 additions and 705 deletions

View File

@ -0,0 +1,46 @@
package user
import "gitlab.protontech.ch/go/liteapi"
type addrList struct {
apiAddrs ordMap[string, string, liteapi.Address]
}
func newAddrList(apiAddrs []liteapi.Address) *addrList {
return &addrList{
apiAddrs: newOrdMap(
func(addr liteapi.Address) string { return addr.ID },
func(addr liteapi.Address) string { return addr.Email },
func(a, b liteapi.Address) bool { return a.Order < b.Order },
apiAddrs...,
),
}
}
func (list *addrList) insert(address liteapi.Address) {
list.apiAddrs.insert(address)
}
func (list *addrList) delete(addrID string) string {
return list.apiAddrs.delete(addrID)
}
func (list *addrList) primary() string {
return list.apiAddrs.keys()[0]
}
func (list *addrList) addrIDs() []string {
return list.apiAddrs.keys()
}
func (list *addrList) emails() []string {
return list.apiAddrs.values()
}
func (list *addrList) email(addrID string) string {
return list.apiAddrs.get(addrID)
}
func (list *addrList) addrMap() map[string]string {
return list.apiAddrs.toMap()
}

View File

@ -2,16 +2,20 @@ package user
import (
"context"
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
"github.com/bradenaw/juniper/xslices"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/slices"
)
type request struct {
messageID string
addressID string
addrKR *crypto.KeyRing
}
@ -54,8 +58,38 @@ func newBuilder(f fetcher, msgWorkers, attWorkers int) *pool.Pool[request, *imap
return nil, err
}
return getMessageCreatedUpdate(msg, literal)
return newMessageCreatedUpdate(msg, literal)
})
return msgPool
}
func newMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) {
parsedMessage, err := imap.NewParsedMessage(literal)
if err != nil {
return nil, err
}
flags := imap.NewFlagSet()
if !message.Unread {
flags = flags.Add(imap.FlagSeen)
}
if slices.Contains(message.LabelIDs, liteapi.StarredLabel) {
flags = flags.Add(imap.FlagFlagged)
}
imapMessage := imap.Message{
ID: imap.MessageID(message.ID),
Flags: flags,
Date: time.Unix(message.Time, 0),
}
return &imap.MessageCreated{
Message: imapMessage,
Literal: literal,
LabelIDs: mapTo[string, imap.LabelID](xslices.Filter(message.LabelIDs, wantLabelID)),
ParsedMessage: parsedMessage,
}, nil
}

View File

@ -8,5 +8,5 @@ var (
ErrNotSupported = errors.New("not supported")
ErrInvalidReturnPath = errors.New("invalid return path")
ErrInvalidRecipient = errors.New("invalid recipient")
ErrMissingAddressKey = errors.New("missing address key")
ErrMissingAddrKey = errors.New("missing address key")
)

View File

@ -2,43 +2,44 @@ package user
import (
"context"
"fmt"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
// handleAPIEvent handles the given liteapi.Event.
func (user *User) handleAPIEvent(event liteapi.Event) error {
func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error {
if event.User != nil {
if err := user.handleUserEvent(*event.User); err != nil {
if err := user.handleUserEvent(ctx, *event.User); err != nil {
return err
}
}
if len(event.Addresses) > 0 {
if err := user.handleAddressEvents(event.Addresses); err != nil {
if err := user.handleAddressEvents(ctx, event.Addresses); err != nil {
return err
}
}
if event.MailSettings != nil {
if err := user.handleMailSettingsEvent(*event.MailSettings); err != nil {
if err := user.handleMailSettingsEvent(ctx, *event.MailSettings); err != nil {
return err
}
}
if len(event.Labels) > 0 {
if err := user.handleLabelEvents(event.Labels); err != nil {
if err := user.handleLabelEvents(ctx, event.Labels); err != nil {
return err
}
}
if len(event.Messages) > 0 {
if err := user.handleMessageEvents(event.Messages); err != nil {
if err := user.handleMessageEvents(ctx, event.Messages); err != nil {
return err
}
}
@ -47,7 +48,7 @@ func (user *User) handleAPIEvent(event liteapi.Event) error {
}
// handleUserEvent handles the given user event.
func (user *User) handleUserEvent(userEvent liteapi.User) error {
func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) error {
userKR, err := userEvent.Keys.Unlock(user.vault.KeyPass(), nil)
if err != nil {
return err
@ -57,49 +58,31 @@ func (user *User) handleUserEvent(userEvent liteapi.User) error {
user.userKR = userKR
user.notifyCh <- events.UserChanged{
user.eventCh.Enqueue(events.UserChanged{
UserID: user.ID(),
}
})
return nil
}
// handleAddressEvents handles the given address events.
// TODO: If split address mode, need to signal back to bridge to update the addresses!
func (user *User) handleAddressEvents(addressEvents []liteapi.AddressEvent) error {
func (user *User) handleAddressEvents(ctx context.Context, addressEvents []liteapi.AddressEvent) error {
for _, event := range addressEvents {
switch event.Action {
case liteapi.EventDelete:
address, err := user.deleteAddress(event.ID)
if err != nil {
return err
}
// TODO: This is not the same as addressChangedLogout event!
// That was only relevant in split mode. This is used differently now.
user.notifyCh <- events.UserAddressDeleted{
UserID: user.ID(),
Address: address.Email,
}
case liteapi.EventCreate:
if err := user.createAddress(event.Address); err != nil {
return err
}
user.notifyCh <- events.UserAddressCreated{
UserID: user.ID(),
Address: event.Address.Email,
if err := user.handleCreateAddressEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle create address event: %w", err)
}
case liteapi.EventUpdate:
if err := user.updateAddress(event.Address); err != nil {
return err
if err := user.handleUpdateAddressEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle update address event: %w", err)
}
user.notifyCh <- events.UserAddressChanged{
UserID: user.ID(),
Address: event.Address.Email,
case liteapi.EventDelete:
if err := user.handleDeleteAddressEvent(ctx, event); err != nil {
return fmt.Errorf("failed to delete address: %w", err)
}
}
}
@ -107,111 +90,189 @@ func (user *User) handleAddressEvents(addressEvents []liteapi.AddressEvent) erro
return nil
}
// createAddress creates the given address.
func (user *User) createAddress(address liteapi.Address) error {
addrKR, err := address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
if err != nil {
return err
return fmt.Errorf("failed to unlock address keys: %w", err)
}
if user.imapConn != nil {
user.imapConn.addAddress(address.Email)
user.apiAddrs.insert(event.Address)
user.addrKRs[event.Address.ID] = addrKR
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
if err := user.syncLabels(ctx, event.Address.ID); err != nil {
return fmt.Errorf("failed to sync labels to new address: %w", err)
}
}
user.addresses = append(user.addresses, address)
user.addrKRs[address.ID] = addrKR
user.eventCh.Enqueue(events.UserAddressCreated{
UserID: user.ID(),
AddressID: event.Address.ID,
Email: event.Address.Email,
})
return nil
}
// updateAddress updates the given address.
func (user *User) updateAddress(address liteapi.Address) error {
if _, err := user.deleteAddress(address.ID); err != nil {
return err
func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
if err != nil {
return fmt.Errorf("failed to unlock address keys: %w", err)
}
return user.createAddress(address)
}
user.apiAddrs.insert(event.Address)
// deleteAddress deletes the given address.
func (user *User) deleteAddress(addressID string) (liteapi.Address, error) {
idx := xslices.IndexFunc(user.addresses, func(address liteapi.Address) bool {
return address.ID == addressID
user.addrKRs[event.Address.ID] = addrKR
user.eventCh.Enqueue(events.UserAddressUpdated{
UserID: user.ID(),
AddressID: event.Address.ID,
Email: event.Address.Email,
})
if idx < 0 {
return liteapi.Address{}, ErrNoSuchAddress
return nil
}
func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
email := user.apiAddrs.delete(event.ID)
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.ID].Close()
delete(user.updateCh, event.ID)
}
if user.imapConn != nil {
user.imapConn.remAddress(user.addresses[idx].Email)
}
user.eventCh.Enqueue(events.UserAddressDeleted{
UserID: user.ID(),
AddressID: event.ID,
Email: email,
})
var address liteapi.Address
address, user.addresses = user.addresses[idx], append(user.addresses[:idx], user.addresses[idx+1:]...)
delete(user.addrKRs, addressID)
return address, nil
return nil
}
// handleMailSettingsEvent handles the given mail settings event.
func (user *User) handleMailSettingsEvent(mailSettingsEvent liteapi.MailSettings) error {
func (user *User) handleMailSettingsEvent(ctx context.Context, mailSettingsEvent liteapi.MailSettings) error {
user.settings = mailSettingsEvent
return nil
}
// handleLabelEvents handles the given label events.
func (user *User) handleLabelEvents(labelEvents []liteapi.LabelEvent) error {
func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.LabelEvent) error {
for _, event := range labelEvents {
switch event.Action {
case liteapi.EventDelete:
user.updateCh <- imap.NewMailboxDeleted(imap.LabelID(event.ID))
case liteapi.EventCreate:
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label))
if err := user.handleCreateLabelEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle create label event: %w", err)
}
case liteapi.EventUpdate, liteapi.EventUpdateFlags:
user.updateCh <- imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label))
if err := user.handleUpdateLabelEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle update label event: %w", err)
}
case liteapi.EventDelete:
if err := user.handleDeleteLabelEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle delete label event: %w", err)
}
}
}
return nil
}
func (user *User) handleCreateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
for _, updateCh := range user.updateCh {
updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label)))
}
return nil
}
func (user *User) handleUpdateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
for _, updateCh := range user.updateCh {
updateCh.Enqueue(imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label)))
}
return nil
}
func (user *User) handleDeleteLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
for _, updateCh := range user.updateCh {
updateCh.Enqueue(imap.NewMailboxDeleted(imap.LabelID(event.ID)))
}
return nil
}
// handleMessageEvents handles the given message events.
func (user *User) handleMessageEvents(messageEvents []liteapi.MessageEvent) error {
ctx, cancel := context.WithCancel(context.Background())
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []liteapi.MessageEvent) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
for _, event := range messageEvents {
switch event.Action {
case liteapi.EventDelete:
return ErrNotImplemented
case liteapi.EventCreate:
messages, err := user.builder.ProcessAll(ctx, []request{{event.ID, user.addrKRs[event.Message.AddressID]}})
if err != nil {
return err
if err := user.handleCreateMessageEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle create message event: %w", err)
}
user.updateCh <- imap.NewMessagesCreated(maps.Values(messages)...)
case liteapi.EventUpdate, liteapi.EventUpdateFlags:
user.updateCh <- imap.NewMessageLabelsUpdated(
imap.MessageID(event.ID),
imapLabelIDs(filterLabelIDs(event.Message.LabelIDs)),
bool(!event.Message.Unread),
slices.Contains(event.Message.LabelIDs, liteapi.StarredLabel),
)
if err := user.handleUpdateMessageEvent(ctx, event); err != nil {
return fmt.Errorf("failed to handle update message event: %w", err)
}
case liteapi.EventDelete:
return ErrNotImplemented
}
}
return nil
}
func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
var addressID string
if user.GetAddressMode() == vault.CombinedMode {
addressID = user.apiAddrs.primary()
} else {
addressID = event.Message.AddressID
}
message, err := user.builder.ProcessOne(ctx, request{
messageID: event.ID,
addressID: addressID,
addrKR: user.addrKRs[event.Message.AddressID],
})
if err != nil {
return err
}
user.updateCh[addressID].Enqueue(imap.NewMessagesCreated(message))
return nil
}
func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
update := imap.NewMessageLabelsUpdated(
imap.MessageID(event.ID),
mapTo[string, imap.LabelID](xslices.Filter(event.Message.LabelIDs, wantLabelID)),
event.Message.Seen(),
event.Message.Starred(),
)
if user.GetAddressMode() == vault.CombinedMode {
user.updateCh[user.apiAddrs.primary()].Enqueue(update)
} else {
user.updateCh[event.Message.AddressID].Enqueue(update)
}
return nil
}
func getMailboxName(label liteapi.Label) []string {
var name []string

76
internal/user/flusher.go Normal file
View File

@ -0,0 +1,76 @@
package user
import (
"sync"
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
)
type flusher struct {
userID string
updateCh *queue.QueuedChannel[imap.Update]
eventCh *queue.QueuedChannel[events.Event]
updates []*imap.MessageCreated
maxChunkSize int
curChunkSize int
count int
total int
start time.Time
pushLock sync.Mutex
}
func newFlusher(
userID string,
updateCh *queue.QueuedChannel[imap.Update],
eventCh *queue.QueuedChannel[events.Event],
total, maxChunkSize int,
) *flusher {
return &flusher{
userID: userID,
updateCh: updateCh,
eventCh: eventCh,
maxChunkSize: maxChunkSize,
total: total,
start: time.Now(),
}
}
func (f *flusher) push(update *imap.MessageCreated) {
f.pushLock.Lock()
defer f.pushLock.Unlock()
f.updates = append(f.updates, update)
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
f.flush()
}
}
func (f *flusher) flush() {
if len(f.updates) == 0 {
return
}
f.count += len(f.updates)
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
f.eventCh.Enqueue(newSyncProgress(f.userID, f.count, f.total, f.start))
f.updates = nil
f.curChunkSize = 0
}
func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress {
return events.SyncProgress{
UserID: userID,
Progress: float64(count) / float64(total),
Elapsed: time.Since(start),
Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count),
}
}

View File

@ -25,11 +25,12 @@ const (
)
type imapConnector struct {
addrID string
client *liteapi.Client
updateCh <-chan imap.Update
addresses []string
password string
emails []string
password string
flags, permFlags, attrs imap.FlagSet
}
@ -37,15 +38,15 @@ type imapConnector struct {
func newIMAPConnector(
client *liteapi.Client,
updateCh <-chan imap.Update,
addresses []string,
password string,
emails ...string,
) *imapConnector {
return &imapConnector{
client: client,
updateCh: updateCh,
addresses: addresses,
password: password,
emails: emails,
password: password,
flags: defaultFlags,
permFlags: defaultPermanentFlags,
@ -59,7 +60,7 @@ func (conn *imapConnector) Authorize(username string, password string) bool {
return false
}
return xslices.IndexFunc(conn.addresses, func(address string) bool {
return xslices.IndexFunc(conn.emails, func(address string) bool {
return strings.EqualFold(address, username)
}) >= 0
}
@ -187,7 +188,7 @@ func (conn *imapConnector) GetMessage(ctx context.Context, messageID imap.Messag
ID: imap.MessageID(message.ID),
Flags: flags,
Date: time.Unix(message.Time, 0),
}, imapLabelIDs(message.LabelIDs), nil
}, mapTo[string, imap.LabelID](message.LabelIDs), nil
}
// CreateMessage creates a new message on the remote.
@ -204,21 +205,21 @@ func (conn *imapConnector) CreateMessage(
// LabelMessages labels the given messages with the given label ID.
func (conn *imapConnector) LabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error {
return conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), string(labelID))
return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelID))
}
// UnlabelMessages unlabels the given messages with the given label ID.
func (conn *imapConnector) UnlabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error {
return conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), string(labelID))
return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelID))
}
// MoveMessages removes the given messages from one label and adds them to the other label.
func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.MessageID, labelFromID imap.LabelID, labelToID imap.LabelID) error {
if err := conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), string(labelToID)); err != nil {
if err := conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
return fmt.Errorf("labeling messages: %w", err)
}
if err := conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), string(labelFromID)); err != nil {
if err := conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
return fmt.Errorf("unlabeling messages: %w", err)
}
@ -228,18 +229,18 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
// MarkMessagesSeen sets the seen value of the given messages.
func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.MessageID, seen bool) error {
if seen {
return conn.client.MarkMessagesRead(ctx, strMessageIDs(messageIDs)...)
return conn.client.MarkMessagesRead(ctx, mapTo[imap.MessageID, string](messageIDs)...)
} else {
return conn.client.MarkMessagesUnread(ctx, strMessageIDs(messageIDs)...)
return conn.client.MarkMessagesUnread(ctx, mapTo[imap.MessageID, string](messageIDs)...)
}
}
// MarkMessagesFlagged sets the flagged value of the given messages.
func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs []imap.MessageID, flagged bool) error {
if flagged {
return conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), liteapi.StarredLabel)
return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel)
} else {
return conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), liteapi.StarredLabel)
return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel)
}
}
@ -249,45 +250,17 @@ func (conn *imapConnector) GetUpdates() <-chan imap.Update {
return conn.updateCh
}
// Close the connector when it will no longer be used and all resources should be closed/released.
func (conn *imapConnector) Close(ctx context.Context) error {
// GetUIDValidity returns the default UID validity for this user.
func (conn *imapConnector) GetUIDValidity() imap.UID {
return imap.UID(1)
}
// SetUIDValidity sets the default UID validity for this user.
func (conn *imapConnector) SetUIDValidity(uidValidity imap.UID) error {
return nil
}
func (conn *imapConnector) addAddress(address string) {
conn.addresses = append(conn.addresses, address)
}
func (conn *imapConnector) remAddress(address string) {
idx := slices.Index(conn.addresses, address)
if idx < 0 {
return
}
conn.addresses = append(conn.addresses[:idx], conn.addresses[idx+1:]...)
}
func strLabelIDs(imapLabelIDs []imap.LabelID) []string {
return xslices.Map(imapLabelIDs, func(labelID imap.LabelID) string {
return string(labelID)
})
}
func imapLabelIDs(labelIDs []string) []imap.LabelID {
return xslices.Map(labelIDs, func(labelID string) imap.LabelID {
return imap.LabelID(labelID)
})
}
func strMessageIDs(imapMessageIDs []imap.MessageID) []string {
return xslices.Map(imapMessageIDs, func(messageID imap.MessageID) string {
return string(messageID)
})
}
func imapMessageIDs(messageIDs []string) []imap.MessageID {
return xslices.Map(messageIDs, func(messageID string) imap.MessageID {
return imap.MessageID(messageID)
})
// Close the connector will no longer be used and all resources should be closed/released.
func (conn *imapConnector) Close(ctx context.Context) error {
return nil
}

89
internal/user/map.go Normal file
View File

@ -0,0 +1,89 @@
package user
import (
"github.com/bradenaw/juniper/xslices"
"golang.org/x/exp/slices"
)
type ordMap[Key comparable, Val, Data any] struct {
data map[Key]Data
order []Key
toKey func(Data) Key
toVal func(Data) Val
isLess func(Data, Data) bool
}
func newOrdMap[Key comparable, Val, Data any](
key func(Data) Key,
value func(Data) Val,
less func(Data, Data) bool,
data ...Data,
) ordMap[Key, Val, Data] {
m := ordMap[Key, Val, Data]{
data: make(map[Key]Data),
toKey: key,
toVal: value,
isLess: less,
}
for _, d := range data {
m.insert(d)
}
return m
}
func (set *ordMap[Key, Val, Data]) insert(data Data) {
if _, ok := set.data[set.toKey(data)]; ok {
set.delete(set.toKey(data))
}
set.data[set.toKey(data)] = data
set.order = append(set.order, set.toKey(data))
slices.SortFunc(set.order, func(a, b Key) bool {
return set.isLess(set.data[a], set.data[b])
})
}
func (set *ordMap[Key, Val, Data]) delete(key Key) Val {
data, ok := set.data[key]
if !ok {
return *new(Val)
}
delete(set.data, key)
set.order = xslices.Filter(set.order, func(otherKey Key) bool {
return otherKey != key
})
return set.toVal(data)
}
func (set *ordMap[Key, Val, Data]) get(key Key) Val {
return set.toVal(set.data[key])
}
func (set *ordMap[Key, Val, Data]) keys() []Key {
return set.order
}
func (set *ordMap[Key, Val, Data]) values() []Val {
return xslices.Map(set.order, func(key Key) Val {
return set.toVal(set.data[key])
})
}
func (set *ordMap[Key, Val, Data]) toMap() map[Key]Val {
m := make(map[Key]Val)
for _, key := range set.order {
m[key] = set.toVal(set.data[key])
}
return m
}

48
internal/user/map_test.go Normal file
View File

@ -0,0 +1,48 @@
package user
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestMap(t *testing.T) {
type Key int
type Value string
type Data struct {
key Key
value Value
}
m := newOrdMap(
func(d Data) Key { return d.key },
func(d Data) Value { return d.value },
func(a, b Data) bool { return a.key < b.key },
Data{key: 1, value: "a"},
Data{key: 2, value: "b"},
Data{key: 3, value: "c"},
)
// Insert some new data.
m.insert(Data{key: 4, value: "d"})
m.insert(Data{key: 5, value: "e"})
// Delete some data.
require.Equal(t, Value("c"), m.delete(3))
require.Equal(t, Value("a"), m.delete(1))
require.Equal(t, Value("e"), m.delete(5))
// Check the remaining keys and values are correct.
require.Equal(t, []Key{2, 4}, m.keys())
require.Equal(t, []Value{"b", "d"}, m.values())
// Overwrite some data.
m.insert(Data{key: 2, value: "two"})
m.insert(Data{key: 4, value: "four"})
// Check the remaining keys and values are correct.
require.Equal(t, []Key{2, 4}, m.keys())
require.Equal(t, []Value{"two", "four"}, m.values())
}

View File

@ -20,12 +20,14 @@ import (
)
type smtpSession struct {
client *liteapi.Client
username string
addresses []liteapi.Address
userKR *crypto.KeyRing
addrKRs map[string]*crypto.KeyRing
settings liteapi.MailSettings
client *liteapi.Client
username string
emails map[string]string
settings liteapi.MailSettings
userKR *crypto.KeyRing
addrKRs map[string]*crypto.KeyRing
from string
to map[string]struct{}
@ -34,18 +36,20 @@ type smtpSession struct {
func newSMTPSession(
client *liteapi.Client,
username string,
addresses []liteapi.Address,
addresses map[string]string,
settings liteapi.MailSettings,
userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing,
settings liteapi.MailSettings,
) *smtpSession {
return &smtpSession{
client: client,
username: username,
addresses: addresses,
userKR: userKR,
addrKRs: addrKRs,
settings: settings,
client: client,
username: username,
emails: addresses,
settings: settings,
userKR: userKR,
addrKRs: addrKRs,
from: "",
to: make(map[string]struct{}),
@ -86,15 +90,15 @@ func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error {
return ErrNotImplemented
}
idx := xslices.IndexFunc(session.addresses, func(address liteapi.Address) bool {
return strings.EqualFold(address.Email, from)
})
if idx < 0 {
return ErrInvalidReturnPath
for addrID, email := range session.emails {
if strings.EqualFold(from, email) {
session.from = addrID
}
}
session.from = session.addresses[idx].ID
if session.from == "" {
return ErrInvalidReturnPath
}
return nil
}
@ -129,10 +133,10 @@ func (session *smtpSession) Data(r io.Reader) error {
addrKR, ok := session.addrKRs[session.from]
if !ok {
return ErrMissingAddressKey
return ErrMissingAddrKey
}
addrKR, err := addrKR.FirstKey()
addrKey, err := addrKR.FirstKey()
if err != nil {
return fmt.Errorf("failed to get first key: %w", err)
}
@ -143,7 +147,7 @@ func (session *smtpSession) Data(r io.Reader) error {
}
if session.settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled {
key, err := addrKR.GetKey(0)
key, err := addrKey.GetKey(0)
if err != nil {
return fmt.Errorf("failed to get user public key: %w", err)
}
@ -153,7 +157,7 @@ func (session *smtpSession) Data(r io.Reader) error {
return fmt.Errorf("failed to get user public key: %w", err)
}
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKey.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
}
message, err := message.ParseWithParser(parser)
@ -161,7 +165,7 @@ func (session *smtpSession) Data(r io.Reader) error {
return fmt.Errorf("failed to parse message: %w", err)
}
draft, attKeys, err := session.createDraft(ctx, addrKR, message)
draft, attKeys, err := session.createDraft(ctx, addrKey, message)
if err != nil {
return fmt.Errorf("failed to create draft: %w", err)
}
@ -171,7 +175,7 @@ func (session *smtpSession) Data(r io.Reader) error {
return fmt.Errorf("failed to get recipients: %w", err)
}
req, err := createSendReq(addrKR, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys)
req, err := createSendReq(addrKey, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys)
if err != nil {
return fmt.Errorf("failed to create packages: %w", err)
}

View File

@ -4,57 +4,34 @@ import (
"context"
"fmt"
"strings"
"sync"
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/slices"
)
const chunkSize = 1 << 20
func (user *User) sync(ctx context.Context) error {
user.notifyCh <- events.SyncStarted{
UserID: user.ID(),
}
if err := user.syncLabels(ctx); err != nil {
return fmt.Errorf("failed to sync labels: %w", err)
}
if err := user.syncMessages(ctx); err != nil {
return fmt.Errorf("failed to sync messages: %w", err)
}
user.notifyCh <- events.SyncFinished{
UserID: user.ID(),
}
if err := user.vault.SetSync(true); err != nil {
return fmt.Errorf("failed to update sync status: %w", err)
}
return nil
}
func (user *User) syncLabels(ctx context.Context) error {
func (user *User) syncLabels(ctx context.Context, addrIDs ...string) error {
// Sync the system folders.
system, err := user.client.GetLabels(ctx, liteapi.LabelTypeSystem)
if err != nil {
return err
}
for _, label := range system {
user.updateCh <- newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name)
for _, label := range xslices.Filter(system, func(label liteapi.Label) bool { return wantLabelID(label.ID) }) {
for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name))
}
}
// Create Folders/Labels mailboxes with a random ID and with the \Noselect attribute.
for _, prefix := range []string{folderPrefix, labelPrefix} {
user.updateCh <- newPlaceHolderMailboxCreatedUpdate(prefix)
for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newPlaceHolderMailboxCreatedUpdate(prefix))
}
}
// Sync the API folders.
@ -64,7 +41,9 @@ func (user *User) syncLabels(ctx context.Context) error {
}
for _, folder := range folders {
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path})
for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path}))
}
}
// Sync the API labels.
@ -74,7 +53,9 @@ func (user *User) syncLabels(ctx context.Context) error {
}
for _, label := range labels {
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path})
for _, addrID := range addrIDs {
user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path}))
}
}
return nil
@ -84,27 +65,53 @@ func (user *User) syncMessages(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
// Determine which messages to sync.
// TODO: This needs to be done better using the new API route to retrieve just the message IDs.
metadata, err := user.client.GetAllMessageMetadata(ctx)
if err != nil {
return err
}
// If in split mode, we need to send each message to a different IMAP connector.
isSplitMode := user.vault.AddressMode() == vault.SplitMode
// Collect the build requests -- we need:
// - the message ID to build,
// - the keyring to decrypt the message,
// - and the address to send the message to (for split mode).
requests := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) request {
var addressID string
if isSplitMode {
addressID = metadata.AddressID
} else {
addressID = user.apiAddrs.primary()
}
return request{
messageID: metadata.ID,
addressID: addressID,
addrKR: user.addrKRs[metadata.AddressID],
}
})
flusher := newFlusher(user.ID(), user.updateCh, user.notifyCh, len(metadata), chunkSize)
defer flusher.flush()
// Create the flushers, one per update channel.
flushers := make(map[string]*flusher)
for addrID, updateCh := range user.updateCh {
flusher := newFlusher(user.ID(), updateCh, user.eventCh, len(requests), chunkSize)
defer flusher.flush()
flushers[addrID] = flusher
}
// Build the messages and send them to the correct flusher.
if err := user.builder.Process(ctx, requests, func(req request, res *imap.MessageCreated, err error) error {
if err != nil {
return fmt.Errorf("failed to build message %s: %w", req.messageID, err)
}
flusher.push(res)
flushers[req.addressID].push(res)
return nil
}); err != nil {
@ -114,95 +121,15 @@ func (user *User) syncMessages(ctx context.Context) error {
return nil
}
type flusher struct {
userID string
func (user *User) syncWait() {
for _, updateCh := range user.updateCh {
waiter := imap.NewNoop()
defer waiter.Wait()
updates []*imap.MessageCreated
updateCh chan<- imap.Update
notifyCh chan<- events.Event
maxChunkSize int
curChunkSize int
count int
total int
start time.Time
pushLock sync.Mutex
}
func newFlusher(userID string, updateCh chan<- imap.Update, notifyCh chan<- events.Event, total, maxChunkSize int) *flusher {
return &flusher{
userID: userID,
updateCh: updateCh,
notifyCh: notifyCh,
maxChunkSize: maxChunkSize,
total: total,
start: time.Now(),
updateCh.Enqueue(waiter)
}
}
func (f *flusher) push(update *imap.MessageCreated) {
f.pushLock.Lock()
defer f.pushLock.Unlock()
f.updates = append(f.updates, update)
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
f.flush()
}
}
func (f *flusher) flush() {
if len(f.updates) == 0 {
return
}
f.count += len(f.updates)
f.updateCh <- imap.NewMessagesCreated(f.updates...)
f.notifyCh <- newSyncProgress(f.userID, f.count, f.total, f.start)
f.updates = nil
f.curChunkSize = 0
}
func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress {
return events.SyncProgress{
UserID: userID,
Progress: float64(count) / float64(total),
Elapsed: time.Since(start),
Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count),
}
}
func getMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) {
parsedMessage, err := imap.NewParsedMessage(literal)
if err != nil {
return nil, err
}
flags := imap.NewFlagSet()
if !message.Unread {
flags = flags.Add(imap.FlagSeen)
}
if slices.Contains(message.LabelIDs, liteapi.StarredLabel) {
flags = flags.Add(imap.FlagFlagged)
}
imapMessage := imap.Message{
ID: imap.MessageID(message.ID),
Flags: flags,
Date: time.Unix(message.Time, 0),
}
return &imap.MessageCreated{
Message: imapMessage,
Literal: literal,
LabelIDs: imapLabelIDs(filterLabelIDs(message.LabelIDs)),
ParsedMessage: parsedMessage,
}, nil
}
func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated {
if strings.EqualFold(labelName, imap.Inbox) {
labelName = imap.Inbox
@ -237,18 +164,12 @@ func newMailboxCreatedUpdate(labelID imap.LabelID, labelName []string) *imap.Mai
})
}
func filterLabelIDs(labelIDs []string) []string {
var filteredLabelIDs []string
func wantLabelID(labelID string) bool {
switch labelID {
case liteapi.AllDraftsLabel, liteapi.AllSentLabel, liteapi.OutboxLabel:
return false
for _, labelID := range labelIDs {
switch labelID {
case liteapi.AllDraftsLabel, liteapi.AllSentLabel, liteapi.OutboxLabel:
// ... skip ...
default:
filteredLabelIDs = append(filteredLabelIDs, labelID)
}
default:
return true
}
return filteredLabelIDs
}

13
internal/user/types.go Normal file
View File

@ -0,0 +1,13 @@
package user
import "reflect"
func mapTo[From, To any](from []From) []To {
to := make([]To, 0, len(from))
for _, from := range from {
to = append(to, reflect.ValueOf(from).Convert(reflect.TypeOf(to).Elem()).Interface().(To))
}
return to
}

View File

@ -0,0 +1,20 @@
package user
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestToType(t *testing.T) {
type myString string
// Slices of different types are not equal.
require.NotEqual(t, []myString{"a", "b", "c"}, []string{"a", "b", "c"})
// But converting them to the same type makes them equal.
require.Equal(t, []myString{"a", "b", "c"}, mapTo[string, myString]([]string{"a", "b", "c"}))
// The conversion can happen in the other direction too.
require.Equal(t, []string{"a", "b", "c"}, mapTo[myString, string]([]myString{"a", "b", "c"}))
}

View File

@ -2,19 +2,22 @@ package user
import (
"context"
"fmt"
"runtime"
"time"
"github.com/ProtonMail/gluon"
"github.com/ProtonMail/gluon/connector"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gopenpgp/v2/crypto"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/bradenaw/juniper/xslices"
"github.com/emersion/go-smtp"
"github.com/sirupsen/logrus"
"gitlab.protontech.ch/go/liteapi"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
@ -23,40 +26,38 @@ var (
DefaultEventJitter = 20 * time.Second
)
// TODO: Is it bad to store the key pass in the user? Any worse than storing private keys?
type User struct {
vault *vault.User
client *liteapi.Client
builder *pool.Pool[request, *imap.MessageCreated]
eventCh *queue.QueuedChannel[events.Event]
apiUser liteapi.User
addresses []liteapi.Address
settings liteapi.MailSettings
notifyCh chan events.Event
updateCh chan imap.Update
apiUser liteapi.User
apiAddrs *addrList
userKR *crypto.KeyRing
addrKRs map[string]*crypto.KeyRing
imapConn *imapConnector
settings liteapi.MailSettings
updateCh map[string]*queue.QueuedChannel[imap.Update]
syncWG gluon.WaitGroup
}
func New(
ctx context.Context,
vault *vault.User,
encVault *vault.User,
client *liteapi.Client,
apiUser liteapi.User,
apiAddrs []liteapi.Address,
userKR *crypto.KeyRing,
addrKRs map[string]*crypto.KeyRing,
) (*User, error) {
if vault.EventID() == "" {
if encVault.EventID() == "" {
eventID, err := client.GetLatestEventID(ctx)
if err != nil {
return nil, err
}
if err := vault.SetEventID(eventID); err != nil {
if err := encVault.SetEventID(eventID); err != nil {
return nil, err
}
}
@ -67,19 +68,29 @@ func New(
}
user := &User{
apiUser: apiUser,
addresses: apiAddrs,
settings: settings,
vault: vault,
vault: encVault,
client: client,
builder: newBuilder(client, runtime.NumCPU()*runtime.NumCPU(), runtime.NumCPU()*runtime.NumCPU()),
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
notifyCh: make(chan events.Event),
updateCh: make(chan imap.Update),
apiUser: apiUser,
apiAddrs: newAddrList(apiAddrs),
userKR: userKR,
addrKRs: addrKRs,
userKR: userKR,
addrKRs: addrKRs,
settings: settings,
updateCh: make(map[string]*queue.QueuedChannel[imap.Update]),
}
// Initialize update channels for each of the user's addresses.
for _, addrID := range user.apiAddrs.addrIDs() {
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
// If in combined mode, we only need one update channel.
if encVault.AddressMode() == vault.CombinedMode {
break
}
}
// When we receive an auth object, we update it in the store.
@ -93,111 +104,234 @@ func New(
// When we are deauthorized, we send a deauth event to the notify channel.
// Bridge will catch this and log the user out.
client.AddDeauthHandler(func() {
user.notifyCh <- events.UserDeauth{
user.eventCh.Enqueue(events.UserDeauth{
UserID: user.ID(),
}
})
})
// When we receive an API event, we attempt to handle it. If successful, we send the event to the event channel.
// When we receive an API event, we attempt to handle it.
// If successful, we update the event ID in the vault.
go func() {
for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, vault.EventID()).Subscribe() {
if err := user.handleAPIEvent(event); err != nil {
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, encVault.EventID()).Subscribe() {
if err := user.handleAPIEvent(ctx, event); err != nil {
logrus.WithError(err).Error("Failed to handle event")
} else {
if err := user.vault.SetEventID(event.EventID); err != nil {
logrus.WithError(err).Error("Failed to update event ID")
}
} else if err := user.vault.SetEventID(event.EventID); err != nil {
logrus.WithError(err).Error("Failed to update event ID")
}
}
}()
// TODO: Use a proper sync manager! (if partial sync, pickup from where we last stopped)
if !vault.HasSync() {
go user.sync(context.Background())
}
return user, nil
}
// ID returns the user's ID.
func (user *User) ID() string {
return user.apiUser.ID
}
// Name returns the user's username.
func (user *User) Name() string {
return user.apiUser.Name
}
// Match matches the given query against the user's username and email addresses.
func (user *User) Match(query string) bool {
if query == user.Name() {
if query == user.apiUser.Name {
return true
}
return slices.Contains(user.Addresses(), query)
return slices.Contains(user.apiAddrs.emails(), query)
}
func (user *User) Addresses() []string {
return xslices.Map(
sort(user.addresses, func(a, b liteapi.Address) bool {
return a.Order < b.Order
}),
func(address liteapi.Address) string {
return address.Email
},
)
// Emails returns all the user's email addresses.
func (user *User) Emails() []string {
return user.apiAddrs.emails()
}
func (user *User) GluonID() string {
return user.vault.GluonID()
// GetAddressMode returns the user's current address mode.
func (user *User) GetAddressMode() vault.AddressMode {
return user.vault.AddressMode()
}
// SetAddressMode sets the user's address mode.
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
for _, updateCh := range user.updateCh {
updateCh.Close()
}
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
for _, addrID := range user.apiAddrs.addrIDs() {
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
if mode == vault.CombinedMode {
break
}
}
if err := user.vault.SetAddressMode(mode); err != nil {
return fmt.Errorf("failed to set address mode: %w", err)
}
return nil
}
// GetGluonIDs returns the users gluon IDs.
func (user *User) GetGluonIDs() map[string]string {
return user.vault.GetGluonIDs()
}
// GetGluonID returns the gluon ID for the given address, if present.
func (user *User) GetGluonID(addrID string) (string, bool) {
gluonID, ok := user.vault.GetGluonIDs()[addrID]
if !ok {
return "", false
}
return gluonID, true
}
// SetGluonID sets the gluon ID for the given address.
func (user *User) SetGluonID(addrID, gluonID string) error {
return user.vault.SetGluonID(addrID, gluonID)
}
// GluonKey returns the user's gluon key from the vault.
func (user *User) GluonKey() []byte {
return user.vault.GluonKey()
}
// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP.
func (user *User) BridgePass() string {
return user.vault.BridgePass()
}
// UsedSpace returns the total space used by the user on the API.
func (user *User) UsedSpace() int {
return user.apiUser.UsedSpace
}
// MaxSpace returns the amount of space the user can use on the API.
func (user *User) MaxSpace() int {
return user.apiUser.MaxSpace
}
// GetNotifyCh returns a channel which notifies of events happening to the user (such as deauth, address change)
func (user *User) GetNotifyCh() <-chan events.Event {
return user.notifyCh
// HasSync returns whether the user has finished syncing.
func (user *User) HasSync() bool {
return user.vault.HasSync()
}
func (user *User) NewGluonConnector(ctx context.Context) (connector.Connector, error) {
if user.imapConn != nil {
if err := user.imapConn.Close(ctx); err != nil {
return nil, err
// AbortSync aborts any ongoing sync.
// TODO: This should abort the sync rather than just waiting.
// Should probably be done automatically when one of the user's IMAP connectors is closed.
func (user *User) AbortSync(ctx context.Context) error {
user.syncWG.Wait()
return nil
}
// DoSync performs a sync for the user.
func (user *User) DoSync(ctx context.Context) <-chan error {
errCh := queue.NewQueuedChannel[error](0, 0)
user.syncWG.Go(func() {
defer errCh.Close()
user.eventCh.Enqueue(events.SyncStarted{
UserID: user.ID(),
})
errCh.Enqueue(func() error {
if err := user.syncLabels(ctx, maps.Keys(user.updateCh)...); err != nil {
return fmt.Errorf("failed to sync labels: %w", err)
}
if err := user.syncMessages(ctx); err != nil {
return fmt.Errorf("failed to sync messages: %w", err)
}
user.syncWait()
if err := user.vault.SetSync(true); err != nil {
return fmt.Errorf("failed to set sync status: %w", err)
}
return nil
}())
user.eventCh.Enqueue(events.SyncFinished{
UserID: user.ID(),
})
})
return errCh.GetChannel()
}
// GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change)
func (user *User) GetEventCh() <-chan events.Event {
return user.eventCh.GetChannel()
}
// NewIMAPConnector returns an IMAP connector for the given address.
// If not in split mode, this function returns an error.
func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
var emails []string
switch user.vault.AddressMode() {
case vault.CombinedMode:
if addrID != user.apiAddrs.primary() {
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
}
emails = user.apiAddrs.emails()
case vault.SplitMode:
emails = []string{user.apiAddrs.email(addrID)}
}
user.imapConn = newIMAPConnector(user.client, user.updateCh, user.Addresses(), user.vault.BridgePass())
return user.imapConn, nil
return newIMAPConnector(
user.client,
user.updateCh[addrID].GetChannel(),
user.vault.BridgePass(),
emails...,
), nil
}
func (user *User) NewSMTPSession(username string) (smtp.Session, error) {
return newSMTPSession(user.client, username, user.addresses, user.userKR, user.addrKRs, user.settings), nil
// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
// In combined mode, this is just the user's primary address.
// In split mode, this is all the user's addresses.
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
imapConn := make(map[string]connector.Connector)
for addrID := range user.updateCh {
conn, err := user.NewIMAPConnector(addrID)
if err != nil {
return nil, fmt.Errorf("failed to create IMAP connector: %w", err)
}
imapConn[addrID] = conn
}
return imapConn, nil
}
// NewSMTPSession returns an SMTP session for the user.
func (user *User) NewSMTPSession(username string) smtp.Session {
return newSMTPSession(user.client, username, user.apiAddrs.addrMap(), user.settings, user.userKR, user.addrKRs)
}
// Logout logs the user out from the API.
func (user *User) Logout(ctx context.Context) error {
return user.client.AuthDelete(ctx)
}
// Close closes ongoing connections and cleans up resources.
func (user *User) Close(ctx context.Context) error {
// Close the user's IMAP connectors.
if user.imapConn != nil {
if err := user.imapConn.Close(ctx); err != nil {
return err
}
}
// Wait for ongoing syncs to finish.
user.syncWG.Wait()
// Close the user's message builder.
user.builder.Done()
@ -205,15 +339,13 @@ func (user *User) Close(ctx context.Context) error {
// Close the user's API client.
user.client.Close()
// Close the user's update channels.
for _, updateCh := range user.updateCh {
updateCh.Close()
}
// Close the user's notify channel.
close(user.notifyCh)
user.eventCh.Close()
return nil
}
// sort returns the slice, sorted by the given callback.
func sort[T any](slice []T, less func(a, b T) bool) []T {
slices.SortFunc(slice, less)
return slice
}

162
internal/user/user_test.go Normal file
View File

@ -0,0 +1,162 @@
package user_test
import (
"context"
"testing"
"time"
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
"github.com/ProtonMail/proton-bridge/v2/internal/events"
"github.com/ProtonMail/proton-bridge/v2/internal/user"
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
"github.com/ProtonMail/proton-bridge/v2/tests"
"github.com/bradenaw/juniper/iterator"
"github.com/emersion/go-imap"
"github.com/emersion/go-imap/client"
"github.com/stretchr/testify/require"
"gitlab.protontech.ch/go/liteapi"
"gitlab.protontech.ch/go/liteapi/server"
"gitlab.protontech.ch/go/liteapi/server/account"
)
func init() {
user.DefaultEventPeriod = 100 * time.Millisecond
user.DefaultEventJitter = 0
account.GenerateKey = tests.FastGenerateKey
certs.GenerateCert = tests.FastGenerateCert
}
func TestUser_Data(t *testing.T) {
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
// User's ID should be correct.
require.Equal(t, userID, user.ID())
// User's name should be correct.
require.Equal(t, "username", user.Name())
// User's email should be correct.
require.ElementsMatch(t, []string{"email@pm.me", "alias@pm.me"}, user.Emails())
// By default, user should be in combined mode.
require.Equal(t, vault.CombinedMode, user.GetAddressMode())
// By default, user should have a non-empty bridge password.
require.NotEmpty(t, user.BridgePass())
})
})
}
func TestUser_Sync(t *testing.T) {
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
// Get the user's IMAP connectors.
imapConn, err := user.NewIMAPConnectors()
require.NoError(t, err)
// Pretend to be gluon applying all the updates.
go func() {
for _, imapConn := range imapConn {
for update := range imapConn.GetUpdates() {
update.Done()
}
}
}()
// Trigger a user sync.
errCh := user.DoSync(ctx)
// User starts a sync at startup.
require.IsType(t, events.SyncStarted{}, <-user.GetEventCh())
// User finishes a sync at startup.
require.IsType(t, events.SyncFinished{}, <-user.GetEventCh())
// The sync completes without error.
require.NoError(t, <-errCh)
})
})
}
func TestUser_Deauth(t *testing.T) {
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
eventCh := user.GetEventCh()
// Revoke the user's auth token.
require.NoError(t, s.RevokeUser(userID))
// The user should eventually be logged out.
require.Eventually(t, func() bool { _, ok := (<-eventCh).(events.UserDeauth); return ok }, 5*time.Second, 100*time.Millisecond)
})
})
}
func withAPI(t *testing.T, ctx context.Context, username, password string, emails []string, fn func(context.Context, *server.Server, string, []string)) {
server := server.New()
defer server.Close()
var addrIDs []string
userID, addrID, err := server.AddUser(username, password, emails[0])
require.NoError(t, err)
addrIDs = append(addrIDs, addrID)
for _, email := range emails[1:] {
addrID, err := server.AddAddress(userID, email, password)
require.NoError(t, err)
addrIDs = append(addrIDs, addrID)
}
fn(ctx, server, userID, addrIDs)
}
func withUser(t *testing.T, ctx context.Context, apiURL, username, password string, fn func(*user.User)) {
c, apiAuth, err := liteapi.New(liteapi.WithHostURL(apiURL)).NewClientWithLogin(ctx, username, password)
require.NoError(t, err)
defer func() { require.NoError(t, c.Close()) }()
apiUser, apiAddrs, userKR, addrKRs, passphrase, err := c.Unlock(ctx, []byte(password))
require.NoError(t, err)
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"))
require.NoError(t, err)
require.False(t, corrupt)
vaultUser, err := vault.AddUser(apiUser.ID, username, apiAuth.UID, apiAuth.RefreshToken, passphrase)
require.NoError(t, err)
user, err := user.New(ctx, vaultUser, c, apiUser, apiAddrs, userKR, addrKRs)
require.NoError(t, err)
defer func() { require.NoError(t, user.Close(ctx)) }()
fn(user)
}
func withIMAPClient(t *testing.T, addr string, fn func(*client.Client)) {
c, err := client.Dial(addr)
require.NoError(t, err)
defer c.Close()
fn(c)
}
func fetch(t *testing.T, c *client.Client, seqset string, items ...imap.FetchItem) []*imap.Message {
msgCh := make(chan *imap.Message)
go func() {
require.NoError(t, c.Fetch(must(imap.ParseSeqSet(seqset)), items, msgCh))
}()
return iterator.Collect(iterator.Chan(msgCh))
}
func must[T any](v T, err error) T {
if err != nil {
panic(err)
}
return v
}