forked from Silverfish/proton-bridge
GODT-1815: Combined/Split mode
This commit is contained in:
46
internal/user/addresses.go
Normal file
46
internal/user/addresses.go
Normal file
@ -0,0 +1,46 @@
|
||||
package user
|
||||
|
||||
import "gitlab.protontech.ch/go/liteapi"
|
||||
|
||||
type addrList struct {
|
||||
apiAddrs ordMap[string, string, liteapi.Address]
|
||||
}
|
||||
|
||||
func newAddrList(apiAddrs []liteapi.Address) *addrList {
|
||||
return &addrList{
|
||||
apiAddrs: newOrdMap(
|
||||
func(addr liteapi.Address) string { return addr.ID },
|
||||
func(addr liteapi.Address) string { return addr.Email },
|
||||
func(a, b liteapi.Address) bool { return a.Order < b.Order },
|
||||
apiAddrs...,
|
||||
),
|
||||
}
|
||||
}
|
||||
|
||||
func (list *addrList) insert(address liteapi.Address) {
|
||||
list.apiAddrs.insert(address)
|
||||
}
|
||||
|
||||
func (list *addrList) delete(addrID string) string {
|
||||
return list.apiAddrs.delete(addrID)
|
||||
}
|
||||
|
||||
func (list *addrList) primary() string {
|
||||
return list.apiAddrs.keys()[0]
|
||||
}
|
||||
|
||||
func (list *addrList) addrIDs() []string {
|
||||
return list.apiAddrs.keys()
|
||||
}
|
||||
|
||||
func (list *addrList) emails() []string {
|
||||
return list.apiAddrs.values()
|
||||
}
|
||||
|
||||
func (list *addrList) email(addrID string) string {
|
||||
return list.apiAddrs.get(addrID)
|
||||
}
|
||||
|
||||
func (list *addrList) addrMap() map[string]string {
|
||||
return list.apiAddrs.toMap()
|
||||
}
|
||||
@ -2,16 +2,20 @@ package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
|
||||
"github.com/ProtonMail/proton-bridge/v2/pkg/message"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type request struct {
|
||||
messageID string
|
||||
addressID string
|
||||
addrKR *crypto.KeyRing
|
||||
}
|
||||
|
||||
@ -54,8 +58,38 @@ func newBuilder(f fetcher, msgWorkers, attWorkers int) *pool.Pool[request, *imap
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return getMessageCreatedUpdate(msg, literal)
|
||||
return newMessageCreatedUpdate(msg, literal)
|
||||
})
|
||||
|
||||
return msgPool
|
||||
}
|
||||
|
||||
func newMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) {
|
||||
parsedMessage, err := imap.NewParsedMessage(literal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flags := imap.NewFlagSet()
|
||||
|
||||
if !message.Unread {
|
||||
flags = flags.Add(imap.FlagSeen)
|
||||
}
|
||||
|
||||
if slices.Contains(message.LabelIDs, liteapi.StarredLabel) {
|
||||
flags = flags.Add(imap.FlagFlagged)
|
||||
}
|
||||
|
||||
imapMessage := imap.Message{
|
||||
ID: imap.MessageID(message.ID),
|
||||
Flags: flags,
|
||||
Date: time.Unix(message.Time, 0),
|
||||
}
|
||||
|
||||
return &imap.MessageCreated{
|
||||
Message: imapMessage,
|
||||
Literal: literal,
|
||||
LabelIDs: mapTo[string, imap.LabelID](xslices.Filter(message.LabelIDs, wantLabelID)),
|
||||
ParsedMessage: parsedMessage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -8,5 +8,5 @@ var (
|
||||
ErrNotSupported = errors.New("not supported")
|
||||
ErrInvalidReturnPath = errors.New("invalid return path")
|
||||
ErrInvalidRecipient = errors.New("invalid recipient")
|
||||
ErrMissingAddressKey = errors.New("missing address key")
|
||||
ErrMissingAddrKey = errors.New("missing address key")
|
||||
)
|
||||
|
||||
@ -2,43 +2,44 @@ package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
// handleAPIEvent handles the given liteapi.Event.
|
||||
func (user *User) handleAPIEvent(event liteapi.Event) error {
|
||||
func (user *User) handleAPIEvent(ctx context.Context, event liteapi.Event) error {
|
||||
if event.User != nil {
|
||||
if err := user.handleUserEvent(*event.User); err != nil {
|
||||
if err := user.handleUserEvent(ctx, *event.User); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(event.Addresses) > 0 {
|
||||
if err := user.handleAddressEvents(event.Addresses); err != nil {
|
||||
if err := user.handleAddressEvents(ctx, event.Addresses); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if event.MailSettings != nil {
|
||||
if err := user.handleMailSettingsEvent(*event.MailSettings); err != nil {
|
||||
if err := user.handleMailSettingsEvent(ctx, *event.MailSettings); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(event.Labels) > 0 {
|
||||
if err := user.handleLabelEvents(event.Labels); err != nil {
|
||||
if err := user.handleLabelEvents(ctx, event.Labels); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
if len(event.Messages) > 0 {
|
||||
if err := user.handleMessageEvents(event.Messages); err != nil {
|
||||
if err := user.handleMessageEvents(ctx, event.Messages); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
@ -47,7 +48,7 @@ func (user *User) handleAPIEvent(event liteapi.Event) error {
|
||||
}
|
||||
|
||||
// handleUserEvent handles the given user event.
|
||||
func (user *User) handleUserEvent(userEvent liteapi.User) error {
|
||||
func (user *User) handleUserEvent(ctx context.Context, userEvent liteapi.User) error {
|
||||
userKR, err := userEvent.Keys.Unlock(user.vault.KeyPass(), nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -57,49 +58,31 @@ func (user *User) handleUserEvent(userEvent liteapi.User) error {
|
||||
|
||||
user.userKR = userKR
|
||||
|
||||
user.notifyCh <- events.UserChanged{
|
||||
user.eventCh.Enqueue(events.UserChanged{
|
||||
UserID: user.ID(),
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleAddressEvents handles the given address events.
|
||||
// TODO: If split address mode, need to signal back to bridge to update the addresses!
|
||||
func (user *User) handleAddressEvents(addressEvents []liteapi.AddressEvent) error {
|
||||
func (user *User) handleAddressEvents(ctx context.Context, addressEvents []liteapi.AddressEvent) error {
|
||||
for _, event := range addressEvents {
|
||||
switch event.Action {
|
||||
case liteapi.EventDelete:
|
||||
address, err := user.deleteAddress(event.ID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// TODO: This is not the same as addressChangedLogout event!
|
||||
// That was only relevant in split mode. This is used differently now.
|
||||
user.notifyCh <- events.UserAddressDeleted{
|
||||
UserID: user.ID(),
|
||||
Address: address.Email,
|
||||
}
|
||||
|
||||
case liteapi.EventCreate:
|
||||
if err := user.createAddress(event.Address); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.notifyCh <- events.UserAddressCreated{
|
||||
UserID: user.ID(),
|
||||
Address: event.Address.Email,
|
||||
if err := user.handleCreateAddressEvent(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to handle create address event: %w", err)
|
||||
}
|
||||
|
||||
case liteapi.EventUpdate:
|
||||
if err := user.updateAddress(event.Address); err != nil {
|
||||
return err
|
||||
if err := user.handleUpdateAddressEvent(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to handle update address event: %w", err)
|
||||
}
|
||||
|
||||
user.notifyCh <- events.UserAddressChanged{
|
||||
UserID: user.ID(),
|
||||
Address: event.Address.Email,
|
||||
case liteapi.EventDelete:
|
||||
if err := user.handleDeleteAddressEvent(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to delete address: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -107,111 +90,189 @@ func (user *User) handleAddressEvents(addressEvents []liteapi.AddressEvent) erro
|
||||
return nil
|
||||
}
|
||||
|
||||
// createAddress creates the given address.
|
||||
func (user *User) createAddress(address liteapi.Address) error {
|
||||
addrKR, err := address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
|
||||
func (user *User) handleCreateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||
addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
|
||||
if err != nil {
|
||||
return err
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
|
||||
if user.imapConn != nil {
|
||||
user.imapConn.addAddress(address.Email)
|
||||
user.apiAddrs.insert(event.Address)
|
||||
|
||||
user.addrKRs[event.Address.ID] = addrKR
|
||||
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
if err := user.syncLabels(ctx, event.Address.ID); err != nil {
|
||||
return fmt.Errorf("failed to sync labels to new address: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
user.addresses = append(user.addresses, address)
|
||||
|
||||
user.addrKRs[address.ID] = addrKR
|
||||
user.eventCh.Enqueue(events.UserAddressCreated{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// updateAddress updates the given address.
|
||||
func (user *User) updateAddress(address liteapi.Address) error {
|
||||
if _, err := user.deleteAddress(address.ID); err != nil {
|
||||
return err
|
||||
func (user *User) handleUpdateAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||
addrKR, err := event.Address.Keys.Unlock(user.vault.KeyPass(), user.userKR)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to unlock address keys: %w", err)
|
||||
}
|
||||
|
||||
return user.createAddress(address)
|
||||
}
|
||||
user.apiAddrs.insert(event.Address)
|
||||
|
||||
// deleteAddress deletes the given address.
|
||||
func (user *User) deleteAddress(addressID string) (liteapi.Address, error) {
|
||||
idx := xslices.IndexFunc(user.addresses, func(address liteapi.Address) bool {
|
||||
return address.ID == addressID
|
||||
user.addrKRs[event.Address.ID] = addrKR
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressUpdated{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
if idx < 0 {
|
||||
return liteapi.Address{}, ErrNoSuchAddress
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteAddressEvent(ctx context.Context, event liteapi.AddressEvent) error {
|
||||
email := user.apiAddrs.delete(event.ID)
|
||||
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
user.updateCh[event.ID].Close()
|
||||
delete(user.updateCh, event.ID)
|
||||
}
|
||||
|
||||
if user.imapConn != nil {
|
||||
user.imapConn.remAddress(user.addresses[idx].Email)
|
||||
}
|
||||
user.eventCh.Enqueue(events.UserAddressDeleted{
|
||||
UserID: user.ID(),
|
||||
AddressID: event.ID,
|
||||
Email: email,
|
||||
})
|
||||
|
||||
var address liteapi.Address
|
||||
|
||||
address, user.addresses = user.addresses[idx], append(user.addresses[:idx], user.addresses[idx+1:]...)
|
||||
|
||||
delete(user.addrKRs, addressID)
|
||||
|
||||
return address, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMailSettingsEvent handles the given mail settings event.
|
||||
func (user *User) handleMailSettingsEvent(mailSettingsEvent liteapi.MailSettings) error {
|
||||
func (user *User) handleMailSettingsEvent(ctx context.Context, mailSettingsEvent liteapi.MailSettings) error {
|
||||
user.settings = mailSettingsEvent
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleLabelEvents handles the given label events.
|
||||
func (user *User) handleLabelEvents(labelEvents []liteapi.LabelEvent) error {
|
||||
func (user *User) handleLabelEvents(ctx context.Context, labelEvents []liteapi.LabelEvent) error {
|
||||
for _, event := range labelEvents {
|
||||
switch event.Action {
|
||||
case liteapi.EventDelete:
|
||||
user.updateCh <- imap.NewMailboxDeleted(imap.LabelID(event.ID))
|
||||
|
||||
case liteapi.EventCreate:
|
||||
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label))
|
||||
if err := user.handleCreateLabelEvent(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to handle create label event: %w", err)
|
||||
}
|
||||
|
||||
case liteapi.EventUpdate, liteapi.EventUpdateFlags:
|
||||
user.updateCh <- imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label))
|
||||
if err := user.handleUpdateLabelEvent(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to handle update label event: %w", err)
|
||||
}
|
||||
|
||||
case liteapi.EventDelete:
|
||||
if err := user.handleDeleteLabelEvent(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to handle delete label event: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleCreateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
||||
for _, updateCh := range user.updateCh {
|
||||
updateCh.Enqueue(newMailboxCreatedUpdate(imap.LabelID(event.ID), getMailboxName(event.Label)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
||||
for _, updateCh := range user.updateCh {
|
||||
updateCh.Enqueue(imap.NewMailboxUpdated(imap.LabelID(event.ID), getMailboxName(event.Label)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleDeleteLabelEvent(ctx context.Context, event liteapi.LabelEvent) error {
|
||||
for _, updateCh := range user.updateCh {
|
||||
updateCh.Enqueue(imap.NewMailboxDeleted(imap.LabelID(event.ID)))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleMessageEvents handles the given message events.
|
||||
func (user *User) handleMessageEvents(messageEvents []liteapi.MessageEvent) error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []liteapi.MessageEvent) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
for _, event := range messageEvents {
|
||||
switch event.Action {
|
||||
case liteapi.EventDelete:
|
||||
return ErrNotImplemented
|
||||
|
||||
case liteapi.EventCreate:
|
||||
messages, err := user.builder.ProcessAll(ctx, []request{{event.ID, user.addrKRs[event.Message.AddressID]}})
|
||||
if err != nil {
|
||||
return err
|
||||
if err := user.handleCreateMessageEvent(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to handle create message event: %w", err)
|
||||
}
|
||||
|
||||
user.updateCh <- imap.NewMessagesCreated(maps.Values(messages)...)
|
||||
|
||||
case liteapi.EventUpdate, liteapi.EventUpdateFlags:
|
||||
user.updateCh <- imap.NewMessageLabelsUpdated(
|
||||
imap.MessageID(event.ID),
|
||||
imapLabelIDs(filterLabelIDs(event.Message.LabelIDs)),
|
||||
bool(!event.Message.Unread),
|
||||
slices.Contains(event.Message.LabelIDs, liteapi.StarredLabel),
|
||||
)
|
||||
if err := user.handleUpdateMessageEvent(ctx, event); err != nil {
|
||||
return fmt.Errorf("failed to handle update message event: %w", err)
|
||||
}
|
||||
|
||||
case liteapi.EventDelete:
|
||||
return ErrNotImplemented
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleCreateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
|
||||
var addressID string
|
||||
|
||||
if user.GetAddressMode() == vault.CombinedMode {
|
||||
addressID = user.apiAddrs.primary()
|
||||
} else {
|
||||
addressID = event.Message.AddressID
|
||||
}
|
||||
|
||||
message, err := user.builder.ProcessOne(ctx, request{
|
||||
messageID: event.ID,
|
||||
addressID: addressID,
|
||||
addrKR: user.addrKRs[event.Message.AddressID],
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
user.updateCh[addressID].Enqueue(imap.NewMessagesCreated(message))
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) handleUpdateMessageEvent(ctx context.Context, event liteapi.MessageEvent) error {
|
||||
update := imap.NewMessageLabelsUpdated(
|
||||
imap.MessageID(event.ID),
|
||||
mapTo[string, imap.LabelID](xslices.Filter(event.Message.LabelIDs, wantLabelID)),
|
||||
event.Message.Seen(),
|
||||
event.Message.Starred(),
|
||||
)
|
||||
|
||||
if user.GetAddressMode() == vault.CombinedMode {
|
||||
user.updateCh[user.apiAddrs.primary()].Enqueue(update)
|
||||
} else {
|
||||
user.updateCh[event.Message.AddressID].Enqueue(update)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMailboxName(label liteapi.Label) []string {
|
||||
var name []string
|
||||
|
||||
|
||||
76
internal/user/flusher.go
Normal file
76
internal/user/flusher.go
Normal file
@ -0,0 +1,76 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
)
|
||||
|
||||
type flusher struct {
|
||||
userID string
|
||||
updateCh *queue.QueuedChannel[imap.Update]
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
|
||||
updates []*imap.MessageCreated
|
||||
maxChunkSize int
|
||||
curChunkSize int
|
||||
|
||||
count int
|
||||
total int
|
||||
start time.Time
|
||||
|
||||
pushLock sync.Mutex
|
||||
}
|
||||
|
||||
func newFlusher(
|
||||
userID string,
|
||||
updateCh *queue.QueuedChannel[imap.Update],
|
||||
eventCh *queue.QueuedChannel[events.Event],
|
||||
total, maxChunkSize int,
|
||||
) *flusher {
|
||||
return &flusher{
|
||||
userID: userID,
|
||||
updateCh: updateCh,
|
||||
eventCh: eventCh,
|
||||
|
||||
maxChunkSize: maxChunkSize,
|
||||
|
||||
total: total,
|
||||
start: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) push(update *imap.MessageCreated) {
|
||||
f.pushLock.Lock()
|
||||
defer f.pushLock.Unlock()
|
||||
|
||||
f.updates = append(f.updates, update)
|
||||
|
||||
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
|
||||
f.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) flush() {
|
||||
if len(f.updates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
f.count += len(f.updates)
|
||||
f.updateCh.Enqueue(imap.NewMessagesCreated(f.updates...))
|
||||
f.eventCh.Enqueue(newSyncProgress(f.userID, f.count, f.total, f.start))
|
||||
f.updates = nil
|
||||
f.curChunkSize = 0
|
||||
}
|
||||
|
||||
func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress {
|
||||
return events.SyncProgress{
|
||||
UserID: userID,
|
||||
Progress: float64(count) / float64(total),
|
||||
Elapsed: time.Since(start),
|
||||
Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count),
|
||||
}
|
||||
}
|
||||
@ -25,11 +25,12 @@ const (
|
||||
)
|
||||
|
||||
type imapConnector struct {
|
||||
addrID string
|
||||
client *liteapi.Client
|
||||
updateCh <-chan imap.Update
|
||||
|
||||
addresses []string
|
||||
password string
|
||||
emails []string
|
||||
password string
|
||||
|
||||
flags, permFlags, attrs imap.FlagSet
|
||||
}
|
||||
@ -37,15 +38,15 @@ type imapConnector struct {
|
||||
func newIMAPConnector(
|
||||
client *liteapi.Client,
|
||||
updateCh <-chan imap.Update,
|
||||
addresses []string,
|
||||
password string,
|
||||
emails ...string,
|
||||
) *imapConnector {
|
||||
return &imapConnector{
|
||||
client: client,
|
||||
updateCh: updateCh,
|
||||
|
||||
addresses: addresses,
|
||||
password: password,
|
||||
emails: emails,
|
||||
password: password,
|
||||
|
||||
flags: defaultFlags,
|
||||
permFlags: defaultPermanentFlags,
|
||||
@ -59,7 +60,7 @@ func (conn *imapConnector) Authorize(username string, password string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
return xslices.IndexFunc(conn.addresses, func(address string) bool {
|
||||
return xslices.IndexFunc(conn.emails, func(address string) bool {
|
||||
return strings.EqualFold(address, username)
|
||||
}) >= 0
|
||||
}
|
||||
@ -187,7 +188,7 @@ func (conn *imapConnector) GetMessage(ctx context.Context, messageID imap.Messag
|
||||
ID: imap.MessageID(message.ID),
|
||||
Flags: flags,
|
||||
Date: time.Unix(message.Time, 0),
|
||||
}, imapLabelIDs(message.LabelIDs), nil
|
||||
}, mapTo[string, imap.LabelID](message.LabelIDs), nil
|
||||
}
|
||||
|
||||
// CreateMessage creates a new message on the remote.
|
||||
@ -204,21 +205,21 @@ func (conn *imapConnector) CreateMessage(
|
||||
|
||||
// LabelMessages labels the given messages with the given label ID.
|
||||
func (conn *imapConnector) LabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error {
|
||||
return conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), string(labelID))
|
||||
return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelID))
|
||||
}
|
||||
|
||||
// UnlabelMessages unlabels the given messages with the given label ID.
|
||||
func (conn *imapConnector) UnlabelMessages(ctx context.Context, messageIDs []imap.MessageID, labelID imap.LabelID) error {
|
||||
return conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), string(labelID))
|
||||
return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelID))
|
||||
}
|
||||
|
||||
// MoveMessages removes the given messages from one label and adds them to the other label.
|
||||
func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.MessageID, labelFromID imap.LabelID, labelToID imap.LabelID) error {
|
||||
if err := conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), string(labelToID)); err != nil {
|
||||
if err := conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelToID)); err != nil {
|
||||
return fmt.Errorf("labeling messages: %w", err)
|
||||
}
|
||||
|
||||
if err := conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), string(labelFromID)); err != nil {
|
||||
if err := conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), string(labelFromID)); err != nil {
|
||||
return fmt.Errorf("unlabeling messages: %w", err)
|
||||
}
|
||||
|
||||
@ -228,18 +229,18 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
|
||||
// MarkMessagesSeen sets the seen value of the given messages.
|
||||
func (conn *imapConnector) MarkMessagesSeen(ctx context.Context, messageIDs []imap.MessageID, seen bool) error {
|
||||
if seen {
|
||||
return conn.client.MarkMessagesRead(ctx, strMessageIDs(messageIDs)...)
|
||||
return conn.client.MarkMessagesRead(ctx, mapTo[imap.MessageID, string](messageIDs)...)
|
||||
} else {
|
||||
return conn.client.MarkMessagesUnread(ctx, strMessageIDs(messageIDs)...)
|
||||
return conn.client.MarkMessagesUnread(ctx, mapTo[imap.MessageID, string](messageIDs)...)
|
||||
}
|
||||
}
|
||||
|
||||
// MarkMessagesFlagged sets the flagged value of the given messages.
|
||||
func (conn *imapConnector) MarkMessagesFlagged(ctx context.Context, messageIDs []imap.MessageID, flagged bool) error {
|
||||
if flagged {
|
||||
return conn.client.LabelMessages(ctx, strMessageIDs(messageIDs), liteapi.StarredLabel)
|
||||
return conn.client.LabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel)
|
||||
} else {
|
||||
return conn.client.UnlabelMessages(ctx, strMessageIDs(messageIDs), liteapi.StarredLabel)
|
||||
return conn.client.UnlabelMessages(ctx, mapTo[imap.MessageID, string](messageIDs), liteapi.StarredLabel)
|
||||
}
|
||||
}
|
||||
|
||||
@ -249,45 +250,17 @@ func (conn *imapConnector) GetUpdates() <-chan imap.Update {
|
||||
return conn.updateCh
|
||||
}
|
||||
|
||||
// Close the connector when it will no longer be used and all resources should be closed/released.
|
||||
func (conn *imapConnector) Close(ctx context.Context) error {
|
||||
// GetUIDValidity returns the default UID validity for this user.
|
||||
func (conn *imapConnector) GetUIDValidity() imap.UID {
|
||||
return imap.UID(1)
|
||||
}
|
||||
|
||||
// SetUIDValidity sets the default UID validity for this user.
|
||||
func (conn *imapConnector) SetUIDValidity(uidValidity imap.UID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (conn *imapConnector) addAddress(address string) {
|
||||
conn.addresses = append(conn.addresses, address)
|
||||
}
|
||||
|
||||
func (conn *imapConnector) remAddress(address string) {
|
||||
idx := slices.Index(conn.addresses, address)
|
||||
|
||||
if idx < 0 {
|
||||
return
|
||||
}
|
||||
|
||||
conn.addresses = append(conn.addresses[:idx], conn.addresses[idx+1:]...)
|
||||
}
|
||||
|
||||
func strLabelIDs(imapLabelIDs []imap.LabelID) []string {
|
||||
return xslices.Map(imapLabelIDs, func(labelID imap.LabelID) string {
|
||||
return string(labelID)
|
||||
})
|
||||
}
|
||||
|
||||
func imapLabelIDs(labelIDs []string) []imap.LabelID {
|
||||
return xslices.Map(labelIDs, func(labelID string) imap.LabelID {
|
||||
return imap.LabelID(labelID)
|
||||
})
|
||||
}
|
||||
|
||||
func strMessageIDs(imapMessageIDs []imap.MessageID) []string {
|
||||
return xslices.Map(imapMessageIDs, func(messageID imap.MessageID) string {
|
||||
return string(messageID)
|
||||
})
|
||||
}
|
||||
|
||||
func imapMessageIDs(messageIDs []string) []imap.MessageID {
|
||||
return xslices.Map(messageIDs, func(messageID string) imap.MessageID {
|
||||
return imap.MessageID(messageID)
|
||||
})
|
||||
// Close the connector will no longer be used and all resources should be closed/released.
|
||||
func (conn *imapConnector) Close(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
89
internal/user/map.go
Normal file
89
internal/user/map.go
Normal file
@ -0,0 +1,89 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
type ordMap[Key comparable, Val, Data any] struct {
|
||||
data map[Key]Data
|
||||
order []Key
|
||||
|
||||
toKey func(Data) Key
|
||||
toVal func(Data) Val
|
||||
isLess func(Data, Data) bool
|
||||
}
|
||||
|
||||
func newOrdMap[Key comparable, Val, Data any](
|
||||
key func(Data) Key,
|
||||
value func(Data) Val,
|
||||
less func(Data, Data) bool,
|
||||
data ...Data,
|
||||
) ordMap[Key, Val, Data] {
|
||||
m := ordMap[Key, Val, Data]{
|
||||
data: make(map[Key]Data),
|
||||
|
||||
toKey: key,
|
||||
toVal: value,
|
||||
isLess: less,
|
||||
}
|
||||
|
||||
for _, d := range data {
|
||||
m.insert(d)
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) insert(data Data) {
|
||||
if _, ok := set.data[set.toKey(data)]; ok {
|
||||
set.delete(set.toKey(data))
|
||||
}
|
||||
|
||||
set.data[set.toKey(data)] = data
|
||||
|
||||
set.order = append(set.order, set.toKey(data))
|
||||
|
||||
slices.SortFunc(set.order, func(a, b Key) bool {
|
||||
return set.isLess(set.data[a], set.data[b])
|
||||
})
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) delete(key Key) Val {
|
||||
data, ok := set.data[key]
|
||||
if !ok {
|
||||
return *new(Val)
|
||||
}
|
||||
|
||||
delete(set.data, key)
|
||||
|
||||
set.order = xslices.Filter(set.order, func(otherKey Key) bool {
|
||||
return otherKey != key
|
||||
})
|
||||
|
||||
return set.toVal(data)
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) get(key Key) Val {
|
||||
return set.toVal(set.data[key])
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) keys() []Key {
|
||||
return set.order
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) values() []Val {
|
||||
return xslices.Map(set.order, func(key Key) Val {
|
||||
return set.toVal(set.data[key])
|
||||
})
|
||||
}
|
||||
|
||||
func (set *ordMap[Key, Val, Data]) toMap() map[Key]Val {
|
||||
m := make(map[Key]Val)
|
||||
|
||||
for _, key := range set.order {
|
||||
m[key] = set.toVal(set.data[key])
|
||||
}
|
||||
|
||||
return m
|
||||
}
|
||||
48
internal/user/map_test.go
Normal file
48
internal/user/map_test.go
Normal file
@ -0,0 +1,48 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestMap(t *testing.T) {
|
||||
type Key int
|
||||
|
||||
type Value string
|
||||
|
||||
type Data struct {
|
||||
key Key
|
||||
value Value
|
||||
}
|
||||
|
||||
m := newOrdMap(
|
||||
func(d Data) Key { return d.key },
|
||||
func(d Data) Value { return d.value },
|
||||
func(a, b Data) bool { return a.key < b.key },
|
||||
Data{key: 1, value: "a"},
|
||||
Data{key: 2, value: "b"},
|
||||
Data{key: 3, value: "c"},
|
||||
)
|
||||
|
||||
// Insert some new data.
|
||||
m.insert(Data{key: 4, value: "d"})
|
||||
m.insert(Data{key: 5, value: "e"})
|
||||
|
||||
// Delete some data.
|
||||
require.Equal(t, Value("c"), m.delete(3))
|
||||
require.Equal(t, Value("a"), m.delete(1))
|
||||
require.Equal(t, Value("e"), m.delete(5))
|
||||
|
||||
// Check the remaining keys and values are correct.
|
||||
require.Equal(t, []Key{2, 4}, m.keys())
|
||||
require.Equal(t, []Value{"b", "d"}, m.values())
|
||||
|
||||
// Overwrite some data.
|
||||
m.insert(Data{key: 2, value: "two"})
|
||||
m.insert(Data{key: 4, value: "four"})
|
||||
|
||||
// Check the remaining keys and values are correct.
|
||||
require.Equal(t, []Key{2, 4}, m.keys())
|
||||
require.Equal(t, []Value{"two", "four"}, m.values())
|
||||
}
|
||||
@ -20,12 +20,14 @@ import (
|
||||
)
|
||||
|
||||
type smtpSession struct {
|
||||
client *liteapi.Client
|
||||
username string
|
||||
addresses []liteapi.Address
|
||||
userKR *crypto.KeyRing
|
||||
addrKRs map[string]*crypto.KeyRing
|
||||
settings liteapi.MailSettings
|
||||
client *liteapi.Client
|
||||
|
||||
username string
|
||||
emails map[string]string
|
||||
settings liteapi.MailSettings
|
||||
|
||||
userKR *crypto.KeyRing
|
||||
addrKRs map[string]*crypto.KeyRing
|
||||
|
||||
from string
|
||||
to map[string]struct{}
|
||||
@ -34,18 +36,20 @@ type smtpSession struct {
|
||||
func newSMTPSession(
|
||||
client *liteapi.Client,
|
||||
username string,
|
||||
addresses []liteapi.Address,
|
||||
addresses map[string]string,
|
||||
settings liteapi.MailSettings,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
settings liteapi.MailSettings,
|
||||
) *smtpSession {
|
||||
return &smtpSession{
|
||||
client: client,
|
||||
username: username,
|
||||
addresses: addresses,
|
||||
userKR: userKR,
|
||||
addrKRs: addrKRs,
|
||||
settings: settings,
|
||||
client: client,
|
||||
|
||||
username: username,
|
||||
emails: addresses,
|
||||
settings: settings,
|
||||
|
||||
userKR: userKR,
|
||||
addrKRs: addrKRs,
|
||||
|
||||
from: "",
|
||||
to: make(map[string]struct{}),
|
||||
@ -86,15 +90,15 @@ func (session *smtpSession) Mail(from string, opts smtp.MailOptions) error {
|
||||
return ErrNotImplemented
|
||||
}
|
||||
|
||||
idx := xslices.IndexFunc(session.addresses, func(address liteapi.Address) bool {
|
||||
return strings.EqualFold(address.Email, from)
|
||||
})
|
||||
|
||||
if idx < 0 {
|
||||
return ErrInvalidReturnPath
|
||||
for addrID, email := range session.emails {
|
||||
if strings.EqualFold(from, email) {
|
||||
session.from = addrID
|
||||
}
|
||||
}
|
||||
|
||||
session.from = session.addresses[idx].ID
|
||||
if session.from == "" {
|
||||
return ErrInvalidReturnPath
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
@ -129,10 +133,10 @@ func (session *smtpSession) Data(r io.Reader) error {
|
||||
|
||||
addrKR, ok := session.addrKRs[session.from]
|
||||
if !ok {
|
||||
return ErrMissingAddressKey
|
||||
return ErrMissingAddrKey
|
||||
}
|
||||
|
||||
addrKR, err := addrKR.FirstKey()
|
||||
addrKey, err := addrKR.FirstKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get first key: %w", err)
|
||||
}
|
||||
@ -143,7 +147,7 @@ func (session *smtpSession) Data(r io.Reader) error {
|
||||
}
|
||||
|
||||
if session.settings.AttachPublicKey == liteapi.AttachPublicKeyEnabled {
|
||||
key, err := addrKR.GetKey(0)
|
||||
key, err := addrKey.GetKey(0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get user public key: %w", err)
|
||||
}
|
||||
@ -153,7 +157,7 @@ func (session *smtpSession) Data(r io.Reader) error {
|
||||
return fmt.Errorf("failed to get user public key: %w", err)
|
||||
}
|
||||
|
||||
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKR.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
|
||||
parser.AttachPublicKey(pubKey, fmt.Sprintf("publickey - %v - %v", addrKey.GetIdentities()[0].Name, key.GetFingerprint()[:8]))
|
||||
}
|
||||
|
||||
message, err := message.ParseWithParser(parser)
|
||||
@ -161,7 +165,7 @@ func (session *smtpSession) Data(r io.Reader) error {
|
||||
return fmt.Errorf("failed to parse message: %w", err)
|
||||
}
|
||||
|
||||
draft, attKeys, err := session.createDraft(ctx, addrKR, message)
|
||||
draft, attKeys, err := session.createDraft(ctx, addrKey, message)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create draft: %w", err)
|
||||
}
|
||||
@ -171,7 +175,7 @@ func (session *smtpSession) Data(r io.Reader) error {
|
||||
return fmt.Errorf("failed to get recipients: %w", err)
|
||||
}
|
||||
|
||||
req, err := createSendReq(addrKR, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys)
|
||||
req, err := createSendReq(addrKey, message.MIMEBody, message.RichBody, message.PlainBody, recipients, attKeys)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create packages: %w", err)
|
||||
}
|
||||
|
||||
@ -4,57 +4,34 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/google/uuid"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const chunkSize = 1 << 20
|
||||
|
||||
func (user *User) sync(ctx context.Context) error {
|
||||
user.notifyCh <- events.SyncStarted{
|
||||
UserID: user.ID(),
|
||||
}
|
||||
|
||||
if err := user.syncLabels(ctx); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
}
|
||||
|
||||
if err := user.syncMessages(ctx); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
|
||||
user.notifyCh <- events.SyncFinished{
|
||||
UserID: user.ID(),
|
||||
}
|
||||
|
||||
if err := user.vault.SetSync(true); err != nil {
|
||||
return fmt.Errorf("failed to update sync status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (user *User) syncLabels(ctx context.Context) error {
|
||||
func (user *User) syncLabels(ctx context.Context, addrIDs ...string) error {
|
||||
// Sync the system folders.
|
||||
system, err := user.client.GetLabels(ctx, liteapi.LabelTypeSystem)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
for _, label := range system {
|
||||
user.updateCh <- newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name)
|
||||
for _, label := range xslices.Filter(system, func(label liteapi.Label) bool { return wantLabelID(label.ID) }) {
|
||||
for _, addrID := range addrIDs {
|
||||
user.updateCh[addrID].Enqueue(newSystemMailboxCreatedUpdate(imap.LabelID(label.ID), label.Name))
|
||||
}
|
||||
}
|
||||
|
||||
// Create Folders/Labels mailboxes with a random ID and with the \Noselect attribute.
|
||||
for _, prefix := range []string{folderPrefix, labelPrefix} {
|
||||
user.updateCh <- newPlaceHolderMailboxCreatedUpdate(prefix)
|
||||
for _, addrID := range addrIDs {
|
||||
user.updateCh[addrID].Enqueue(newPlaceHolderMailboxCreatedUpdate(prefix))
|
||||
}
|
||||
}
|
||||
|
||||
// Sync the API folders.
|
||||
@ -64,7 +41,9 @@ func (user *User) syncLabels(ctx context.Context) error {
|
||||
}
|
||||
|
||||
for _, folder := range folders {
|
||||
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path})
|
||||
for _, addrID := range addrIDs {
|
||||
user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(folder.ID), []string{folderPrefix, folder.Path}))
|
||||
}
|
||||
}
|
||||
|
||||
// Sync the API labels.
|
||||
@ -74,7 +53,9 @@ func (user *User) syncLabels(ctx context.Context) error {
|
||||
}
|
||||
|
||||
for _, label := range labels {
|
||||
user.updateCh <- newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path})
|
||||
for _, addrID := range addrIDs {
|
||||
user.updateCh[addrID].Enqueue(newMailboxCreatedUpdate(imap.LabelID(label.ID), []string{labelPrefix, label.Path}))
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -84,27 +65,53 @@ func (user *User) syncMessages(ctx context.Context) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
// Determine which messages to sync.
|
||||
// TODO: This needs to be done better using the new API route to retrieve just the message IDs.
|
||||
metadata, err := user.client.GetAllMessageMetadata(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// If in split mode, we need to send each message to a different IMAP connector.
|
||||
isSplitMode := user.vault.AddressMode() == vault.SplitMode
|
||||
|
||||
// Collect the build requests -- we need:
|
||||
// - the message ID to build,
|
||||
// - the keyring to decrypt the message,
|
||||
// - and the address to send the message to (for split mode).
|
||||
requests := xslices.Map(metadata, func(metadata liteapi.MessageMetadata) request {
|
||||
var addressID string
|
||||
|
||||
if isSplitMode {
|
||||
addressID = metadata.AddressID
|
||||
} else {
|
||||
addressID = user.apiAddrs.primary()
|
||||
}
|
||||
|
||||
return request{
|
||||
messageID: metadata.ID,
|
||||
addressID: addressID,
|
||||
addrKR: user.addrKRs[metadata.AddressID],
|
||||
}
|
||||
})
|
||||
|
||||
flusher := newFlusher(user.ID(), user.updateCh, user.notifyCh, len(metadata), chunkSize)
|
||||
defer flusher.flush()
|
||||
// Create the flushers, one per update channel.
|
||||
flushers := make(map[string]*flusher)
|
||||
|
||||
for addrID, updateCh := range user.updateCh {
|
||||
flusher := newFlusher(user.ID(), updateCh, user.eventCh, len(requests), chunkSize)
|
||||
defer flusher.flush()
|
||||
|
||||
flushers[addrID] = flusher
|
||||
}
|
||||
|
||||
// Build the messages and send them to the correct flusher.
|
||||
if err := user.builder.Process(ctx, requests, func(req request, res *imap.MessageCreated, err error) error {
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to build message %s: %w", req.messageID, err)
|
||||
}
|
||||
|
||||
flusher.push(res)
|
||||
flushers[req.addressID].push(res)
|
||||
|
||||
return nil
|
||||
}); err != nil {
|
||||
@ -114,95 +121,15 @@ func (user *User) syncMessages(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type flusher struct {
|
||||
userID string
|
||||
func (user *User) syncWait() {
|
||||
for _, updateCh := range user.updateCh {
|
||||
waiter := imap.NewNoop()
|
||||
defer waiter.Wait()
|
||||
|
||||
updates []*imap.MessageCreated
|
||||
updateCh chan<- imap.Update
|
||||
notifyCh chan<- events.Event
|
||||
maxChunkSize int
|
||||
curChunkSize int
|
||||
|
||||
count int
|
||||
total int
|
||||
start time.Time
|
||||
|
||||
pushLock sync.Mutex
|
||||
}
|
||||
|
||||
func newFlusher(userID string, updateCh chan<- imap.Update, notifyCh chan<- events.Event, total, maxChunkSize int) *flusher {
|
||||
return &flusher{
|
||||
userID: userID,
|
||||
updateCh: updateCh,
|
||||
notifyCh: notifyCh,
|
||||
maxChunkSize: maxChunkSize,
|
||||
total: total,
|
||||
start: time.Now(),
|
||||
updateCh.Enqueue(waiter)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) push(update *imap.MessageCreated) {
|
||||
f.pushLock.Lock()
|
||||
defer f.pushLock.Unlock()
|
||||
|
||||
f.updates = append(f.updates, update)
|
||||
|
||||
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxChunkSize {
|
||||
f.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) flush() {
|
||||
if len(f.updates) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
f.count += len(f.updates)
|
||||
f.updateCh <- imap.NewMessagesCreated(f.updates...)
|
||||
f.notifyCh <- newSyncProgress(f.userID, f.count, f.total, f.start)
|
||||
f.updates = nil
|
||||
f.curChunkSize = 0
|
||||
}
|
||||
|
||||
func newSyncProgress(userID string, count, total int, start time.Time) events.SyncProgress {
|
||||
return events.SyncProgress{
|
||||
UserID: userID,
|
||||
Progress: float64(count) / float64(total),
|
||||
Elapsed: time.Since(start),
|
||||
Remaining: time.Since(start) * time.Duration(total-count) / time.Duration(count),
|
||||
}
|
||||
}
|
||||
|
||||
func getMessageCreatedUpdate(message liteapi.Message, literal []byte) (*imap.MessageCreated, error) {
|
||||
parsedMessage, err := imap.NewParsedMessage(literal)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
flags := imap.NewFlagSet()
|
||||
|
||||
if !message.Unread {
|
||||
flags = flags.Add(imap.FlagSeen)
|
||||
}
|
||||
|
||||
if slices.Contains(message.LabelIDs, liteapi.StarredLabel) {
|
||||
flags = flags.Add(imap.FlagFlagged)
|
||||
}
|
||||
|
||||
imapMessage := imap.Message{
|
||||
ID: imap.MessageID(message.ID),
|
||||
Flags: flags,
|
||||
Date: time.Unix(message.Time, 0),
|
||||
}
|
||||
|
||||
return &imap.MessageCreated{
|
||||
Message: imapMessage,
|
||||
Literal: literal,
|
||||
LabelIDs: imapLabelIDs(filterLabelIDs(message.LabelIDs)),
|
||||
ParsedMessage: parsedMessage,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newSystemMailboxCreatedUpdate(labelID imap.LabelID, labelName string) *imap.MailboxCreated {
|
||||
if strings.EqualFold(labelName, imap.Inbox) {
|
||||
labelName = imap.Inbox
|
||||
@ -237,18 +164,12 @@ func newMailboxCreatedUpdate(labelID imap.LabelID, labelName []string) *imap.Mai
|
||||
})
|
||||
}
|
||||
|
||||
func filterLabelIDs(labelIDs []string) []string {
|
||||
var filteredLabelIDs []string
|
||||
func wantLabelID(labelID string) bool {
|
||||
switch labelID {
|
||||
case liteapi.AllDraftsLabel, liteapi.AllSentLabel, liteapi.OutboxLabel:
|
||||
return false
|
||||
|
||||
for _, labelID := range labelIDs {
|
||||
switch labelID {
|
||||
case liteapi.AllDraftsLabel, liteapi.AllSentLabel, liteapi.OutboxLabel:
|
||||
// ... skip ...
|
||||
|
||||
default:
|
||||
filteredLabelIDs = append(filteredLabelIDs, labelID)
|
||||
}
|
||||
default:
|
||||
return true
|
||||
}
|
||||
|
||||
return filteredLabelIDs
|
||||
}
|
||||
|
||||
13
internal/user/types.go
Normal file
13
internal/user/types.go
Normal file
@ -0,0 +1,13 @@
|
||||
package user
|
||||
|
||||
import "reflect"
|
||||
|
||||
func mapTo[From, To any](from []From) []To {
|
||||
to := make([]To, 0, len(from))
|
||||
|
||||
for _, from := range from {
|
||||
to = append(to, reflect.ValueOf(from).Convert(reflect.TypeOf(to).Elem()).Interface().(To))
|
||||
}
|
||||
|
||||
return to
|
||||
}
|
||||
20
internal/user/types_test.go
Normal file
20
internal/user/types_test.go
Normal file
@ -0,0 +1,20 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestToType(t *testing.T) {
|
||||
type myString string
|
||||
|
||||
// Slices of different types are not equal.
|
||||
require.NotEqual(t, []myString{"a", "b", "c"}, []string{"a", "b", "c"})
|
||||
|
||||
// But converting them to the same type makes them equal.
|
||||
require.Equal(t, []myString{"a", "b", "c"}, mapTo[string, myString]([]string{"a", "b", "c"}))
|
||||
|
||||
// The conversion can happen in the other direction too.
|
||||
require.Equal(t, []string{"a", "b", "c"}, mapTo[myString, string]([]myString{"a", "b", "c"}))
|
||||
}
|
||||
@ -2,19 +2,22 @@ package user
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon"
|
||||
"github.com/ProtonMail/gluon/connector"
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gopenpgp/v2/crypto"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/pool"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/emersion/go-smtp"
|
||||
"github.com/sirupsen/logrus"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
@ -23,40 +26,38 @@ var (
|
||||
DefaultEventJitter = 20 * time.Second
|
||||
)
|
||||
|
||||
// TODO: Is it bad to store the key pass in the user? Any worse than storing private keys?
|
||||
type User struct {
|
||||
vault *vault.User
|
||||
client *liteapi.Client
|
||||
builder *pool.Pool[request, *imap.MessageCreated]
|
||||
eventCh *queue.QueuedChannel[events.Event]
|
||||
|
||||
apiUser liteapi.User
|
||||
addresses []liteapi.Address
|
||||
settings liteapi.MailSettings
|
||||
|
||||
notifyCh chan events.Event
|
||||
updateCh chan imap.Update
|
||||
|
||||
apiUser liteapi.User
|
||||
apiAddrs *addrList
|
||||
userKR *crypto.KeyRing
|
||||
addrKRs map[string]*crypto.KeyRing
|
||||
imapConn *imapConnector
|
||||
settings liteapi.MailSettings
|
||||
|
||||
updateCh map[string]*queue.QueuedChannel[imap.Update]
|
||||
syncWG gluon.WaitGroup
|
||||
}
|
||||
|
||||
func New(
|
||||
ctx context.Context,
|
||||
vault *vault.User,
|
||||
encVault *vault.User,
|
||||
client *liteapi.Client,
|
||||
apiUser liteapi.User,
|
||||
apiAddrs []liteapi.Address,
|
||||
userKR *crypto.KeyRing,
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
) (*User, error) {
|
||||
if vault.EventID() == "" {
|
||||
if encVault.EventID() == "" {
|
||||
eventID, err := client.GetLatestEventID(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := vault.SetEventID(eventID); err != nil {
|
||||
if err := encVault.SetEventID(eventID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
@ -67,19 +68,29 @@ func New(
|
||||
}
|
||||
|
||||
user := &User{
|
||||
apiUser: apiUser,
|
||||
addresses: apiAddrs,
|
||||
settings: settings,
|
||||
|
||||
vault: vault,
|
||||
vault: encVault,
|
||||
client: client,
|
||||
builder: newBuilder(client, runtime.NumCPU()*runtime.NumCPU(), runtime.NumCPU()*runtime.NumCPU()),
|
||||
eventCh: queue.NewQueuedChannel[events.Event](0, 0),
|
||||
|
||||
notifyCh: make(chan events.Event),
|
||||
updateCh: make(chan imap.Update),
|
||||
apiUser: apiUser,
|
||||
apiAddrs: newAddrList(apiAddrs),
|
||||
|
||||
userKR: userKR,
|
||||
addrKRs: addrKRs,
|
||||
userKR: userKR,
|
||||
addrKRs: addrKRs,
|
||||
settings: settings,
|
||||
|
||||
updateCh: make(map[string]*queue.QueuedChannel[imap.Update]),
|
||||
}
|
||||
|
||||
// Initialize update channels for each of the user's addresses.
|
||||
for _, addrID := range user.apiAddrs.addrIDs() {
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
// If in combined mode, we only need one update channel.
|
||||
if encVault.AddressMode() == vault.CombinedMode {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
// When we receive an auth object, we update it in the store.
|
||||
@ -93,111 +104,234 @@ func New(
|
||||
// When we are deauthorized, we send a deauth event to the notify channel.
|
||||
// Bridge will catch this and log the user out.
|
||||
client.AddDeauthHandler(func() {
|
||||
user.notifyCh <- events.UserDeauth{
|
||||
user.eventCh.Enqueue(events.UserDeauth{
|
||||
UserID: user.ID(),
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
// When we receive an API event, we attempt to handle it. If successful, we send the event to the event channel.
|
||||
// When we receive an API event, we attempt to handle it.
|
||||
// If successful, we update the event ID in the vault.
|
||||
go func() {
|
||||
for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, vault.EventID()).Subscribe() {
|
||||
if err := user.handleAPIEvent(event); err != nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
defer cancel()
|
||||
|
||||
for event := range user.client.NewEventStreamer(DefaultEventPeriod, DefaultEventJitter, encVault.EventID()).Subscribe() {
|
||||
if err := user.handleAPIEvent(ctx, event); err != nil {
|
||||
logrus.WithError(err).Error("Failed to handle event")
|
||||
} else {
|
||||
if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to update event ID")
|
||||
}
|
||||
} else if err := user.vault.SetEventID(event.EventID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to update event ID")
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// TODO: Use a proper sync manager! (if partial sync, pickup from where we last stopped)
|
||||
if !vault.HasSync() {
|
||||
go user.sync(context.Background())
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// ID returns the user's ID.
|
||||
func (user *User) ID() string {
|
||||
return user.apiUser.ID
|
||||
}
|
||||
|
||||
// Name returns the user's username.
|
||||
func (user *User) Name() string {
|
||||
return user.apiUser.Name
|
||||
}
|
||||
|
||||
// Match matches the given query against the user's username and email addresses.
|
||||
func (user *User) Match(query string) bool {
|
||||
if query == user.Name() {
|
||||
if query == user.apiUser.Name {
|
||||
return true
|
||||
}
|
||||
|
||||
return slices.Contains(user.Addresses(), query)
|
||||
return slices.Contains(user.apiAddrs.emails(), query)
|
||||
}
|
||||
|
||||
func (user *User) Addresses() []string {
|
||||
return xslices.Map(
|
||||
sort(user.addresses, func(a, b liteapi.Address) bool {
|
||||
return a.Order < b.Order
|
||||
}),
|
||||
func(address liteapi.Address) string {
|
||||
return address.Email
|
||||
},
|
||||
)
|
||||
// Emails returns all the user's email addresses.
|
||||
func (user *User) Emails() []string {
|
||||
return user.apiAddrs.emails()
|
||||
}
|
||||
|
||||
func (user *User) GluonID() string {
|
||||
return user.vault.GluonID()
|
||||
// GetAddressMode returns the user's current address mode.
|
||||
func (user *User) GetAddressMode() vault.AddressMode {
|
||||
return user.vault.AddressMode()
|
||||
}
|
||||
|
||||
// SetAddressMode sets the user's address mode.
|
||||
func (user *User) SetAddressMode(ctx context.Context, mode vault.AddressMode) error {
|
||||
for _, updateCh := range user.updateCh {
|
||||
updateCh.Close()
|
||||
}
|
||||
|
||||
user.updateCh = make(map[string]*queue.QueuedChannel[imap.Update])
|
||||
|
||||
for _, addrID := range user.apiAddrs.addrIDs() {
|
||||
user.updateCh[addrID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
|
||||
if mode == vault.CombinedMode {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if err := user.vault.SetAddressMode(mode); err != nil {
|
||||
return fmt.Errorf("failed to set address mode: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetGluonIDs returns the users gluon IDs.
|
||||
func (user *User) GetGluonIDs() map[string]string {
|
||||
return user.vault.GetGluonIDs()
|
||||
}
|
||||
|
||||
// GetGluonID returns the gluon ID for the given address, if present.
|
||||
func (user *User) GetGluonID(addrID string) (string, bool) {
|
||||
gluonID, ok := user.vault.GetGluonIDs()[addrID]
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return gluonID, true
|
||||
}
|
||||
|
||||
// SetGluonID sets the gluon ID for the given address.
|
||||
func (user *User) SetGluonID(addrID, gluonID string) error {
|
||||
return user.vault.SetGluonID(addrID, gluonID)
|
||||
}
|
||||
|
||||
// GluonKey returns the user's gluon key from the vault.
|
||||
func (user *User) GluonKey() []byte {
|
||||
return user.vault.GluonKey()
|
||||
}
|
||||
|
||||
// BridgePass returns the user's bridge password, used for authentication over SMTP and IMAP.
|
||||
func (user *User) BridgePass() string {
|
||||
return user.vault.BridgePass()
|
||||
}
|
||||
|
||||
// UsedSpace returns the total space used by the user on the API.
|
||||
func (user *User) UsedSpace() int {
|
||||
return user.apiUser.UsedSpace
|
||||
}
|
||||
|
||||
// MaxSpace returns the amount of space the user can use on the API.
|
||||
func (user *User) MaxSpace() int {
|
||||
return user.apiUser.MaxSpace
|
||||
}
|
||||
|
||||
// GetNotifyCh returns a channel which notifies of events happening to the user (such as deauth, address change)
|
||||
func (user *User) GetNotifyCh() <-chan events.Event {
|
||||
return user.notifyCh
|
||||
// HasSync returns whether the user has finished syncing.
|
||||
func (user *User) HasSync() bool {
|
||||
return user.vault.HasSync()
|
||||
}
|
||||
|
||||
func (user *User) NewGluonConnector(ctx context.Context) (connector.Connector, error) {
|
||||
if user.imapConn != nil {
|
||||
if err := user.imapConn.Close(ctx); err != nil {
|
||||
return nil, err
|
||||
// AbortSync aborts any ongoing sync.
|
||||
// TODO: This should abort the sync rather than just waiting.
|
||||
// Should probably be done automatically when one of the user's IMAP connectors is closed.
|
||||
func (user *User) AbortSync(ctx context.Context) error {
|
||||
user.syncWG.Wait()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// DoSync performs a sync for the user.
|
||||
func (user *User) DoSync(ctx context.Context) <-chan error {
|
||||
errCh := queue.NewQueuedChannel[error](0, 0)
|
||||
|
||||
user.syncWG.Go(func() {
|
||||
defer errCh.Close()
|
||||
|
||||
user.eventCh.Enqueue(events.SyncStarted{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
|
||||
errCh.Enqueue(func() error {
|
||||
if err := user.syncLabels(ctx, maps.Keys(user.updateCh)...); err != nil {
|
||||
return fmt.Errorf("failed to sync labels: %w", err)
|
||||
}
|
||||
|
||||
if err := user.syncMessages(ctx); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
|
||||
user.syncWait()
|
||||
|
||||
if err := user.vault.SetSync(true); err != nil {
|
||||
return fmt.Errorf("failed to set sync status: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}())
|
||||
|
||||
user.eventCh.Enqueue(events.SyncFinished{
|
||||
UserID: user.ID(),
|
||||
})
|
||||
})
|
||||
|
||||
return errCh.GetChannel()
|
||||
}
|
||||
|
||||
// GetEventCh returns a channel which notifies of events happening to the user (such as deauth, address change)
|
||||
func (user *User) GetEventCh() <-chan events.Event {
|
||||
return user.eventCh.GetChannel()
|
||||
}
|
||||
|
||||
// NewIMAPConnector returns an IMAP connector for the given address.
|
||||
// If not in split mode, this function returns an error.
|
||||
func (user *User) NewIMAPConnector(addrID string) (connector.Connector, error) {
|
||||
var emails []string
|
||||
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
if addrID != user.apiAddrs.primary() {
|
||||
return nil, fmt.Errorf("cannot create IMAP connector for non-primary address in combined mode")
|
||||
}
|
||||
|
||||
emails = user.apiAddrs.emails()
|
||||
|
||||
case vault.SplitMode:
|
||||
emails = []string{user.apiAddrs.email(addrID)}
|
||||
}
|
||||
|
||||
user.imapConn = newIMAPConnector(user.client, user.updateCh, user.Addresses(), user.vault.BridgePass())
|
||||
|
||||
return user.imapConn, nil
|
||||
return newIMAPConnector(
|
||||
user.client,
|
||||
user.updateCh[addrID].GetChannel(),
|
||||
user.vault.BridgePass(),
|
||||
emails...,
|
||||
), nil
|
||||
}
|
||||
|
||||
func (user *User) NewSMTPSession(username string) (smtp.Session, error) {
|
||||
return newSMTPSession(user.client, username, user.addresses, user.userKR, user.addrKRs, user.settings), nil
|
||||
// NewIMAPConnectors returns IMAP connectors for each of the user's addresses.
|
||||
// In combined mode, this is just the user's primary address.
|
||||
// In split mode, this is all the user's addresses.
|
||||
func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
||||
imapConn := make(map[string]connector.Connector)
|
||||
|
||||
for addrID := range user.updateCh {
|
||||
conn, err := user.NewIMAPConnector(addrID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create IMAP connector: %w", err)
|
||||
}
|
||||
|
||||
imapConn[addrID] = conn
|
||||
}
|
||||
|
||||
return imapConn, nil
|
||||
}
|
||||
|
||||
// NewSMTPSession returns an SMTP session for the user.
|
||||
func (user *User) NewSMTPSession(username string) smtp.Session {
|
||||
return newSMTPSession(user.client, username, user.apiAddrs.addrMap(), user.settings, user.userKR, user.addrKRs)
|
||||
}
|
||||
|
||||
// Logout logs the user out from the API.
|
||||
func (user *User) Logout(ctx context.Context) error {
|
||||
return user.client.AuthDelete(ctx)
|
||||
}
|
||||
|
||||
// Close closes ongoing connections and cleans up resources.
|
||||
func (user *User) Close(ctx context.Context) error {
|
||||
// Close the user's IMAP connectors.
|
||||
if user.imapConn != nil {
|
||||
if err := user.imapConn.Close(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
// Wait for ongoing syncs to finish.
|
||||
user.syncWG.Wait()
|
||||
|
||||
// Close the user's message builder.
|
||||
user.builder.Done()
|
||||
@ -205,15 +339,13 @@ func (user *User) Close(ctx context.Context) error {
|
||||
// Close the user's API client.
|
||||
user.client.Close()
|
||||
|
||||
// Close the user's update channels.
|
||||
for _, updateCh := range user.updateCh {
|
||||
updateCh.Close()
|
||||
}
|
||||
|
||||
// Close the user's notify channel.
|
||||
close(user.notifyCh)
|
||||
user.eventCh.Close()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// sort returns the slice, sorted by the given callback.
|
||||
func sort[T any](slice []T, less func(a, b T) bool) []T {
|
||||
slices.SortFunc(slice, less)
|
||||
|
||||
return slice
|
||||
}
|
||||
|
||||
162
internal/user/user_test.go
Normal file
162
internal/user/user_test.go
Normal file
@ -0,0 +1,162 @@
|
||||
package user_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/certs"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/events"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/user"
|
||||
"github.com/ProtonMail/proton-bridge/v2/internal/vault"
|
||||
"github.com/ProtonMail/proton-bridge/v2/tests"
|
||||
"github.com/bradenaw/juniper/iterator"
|
||||
"github.com/emersion/go-imap"
|
||||
"github.com/emersion/go-imap/client"
|
||||
"github.com/stretchr/testify/require"
|
||||
"gitlab.protontech.ch/go/liteapi"
|
||||
"gitlab.protontech.ch/go/liteapi/server"
|
||||
"gitlab.protontech.ch/go/liteapi/server/account"
|
||||
)
|
||||
|
||||
func init() {
|
||||
user.DefaultEventPeriod = 100 * time.Millisecond
|
||||
user.DefaultEventJitter = 0
|
||||
account.GenerateKey = tests.FastGenerateKey
|
||||
certs.GenerateCert = tests.FastGenerateCert
|
||||
}
|
||||
|
||||
func TestUser_Data(t *testing.T) {
|
||||
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me", "alias@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
|
||||
// User's ID should be correct.
|
||||
require.Equal(t, userID, user.ID())
|
||||
|
||||
// User's name should be correct.
|
||||
require.Equal(t, "username", user.Name())
|
||||
|
||||
// User's email should be correct.
|
||||
require.ElementsMatch(t, []string{"email@pm.me", "alias@pm.me"}, user.Emails())
|
||||
|
||||
// By default, user should be in combined mode.
|
||||
require.Equal(t, vault.CombinedMode, user.GetAddressMode())
|
||||
|
||||
// By default, user should have a non-empty bridge password.
|
||||
require.NotEmpty(t, user.BridgePass())
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_Sync(t *testing.T) {
|
||||
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
|
||||
// Get the user's IMAP connectors.
|
||||
imapConn, err := user.NewIMAPConnectors()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Pretend to be gluon applying all the updates.
|
||||
go func() {
|
||||
for _, imapConn := range imapConn {
|
||||
for update := range imapConn.GetUpdates() {
|
||||
update.Done()
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
// Trigger a user sync.
|
||||
errCh := user.DoSync(ctx)
|
||||
|
||||
// User starts a sync at startup.
|
||||
require.IsType(t, events.SyncStarted{}, <-user.GetEventCh())
|
||||
|
||||
// User finishes a sync at startup.
|
||||
require.IsType(t, events.SyncFinished{}, <-user.GetEventCh())
|
||||
|
||||
// The sync completes without error.
|
||||
require.NoError(t, <-errCh)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func TestUser_Deauth(t *testing.T) {
|
||||
withAPI(t, context.Background(), "username", "password", []string{"email@pm.me"}, func(ctx context.Context, s *server.Server, userID string, addrIDs []string) {
|
||||
withUser(t, ctx, s.GetHostURL(), "username", "password", func(user *user.User) {
|
||||
eventCh := user.GetEventCh()
|
||||
|
||||
// Revoke the user's auth token.
|
||||
require.NoError(t, s.RevokeUser(userID))
|
||||
|
||||
// The user should eventually be logged out.
|
||||
require.Eventually(t, func() bool { _, ok := (<-eventCh).(events.UserDeauth); return ok }, 5*time.Second, 100*time.Millisecond)
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
func withAPI(t *testing.T, ctx context.Context, username, password string, emails []string, fn func(context.Context, *server.Server, string, []string)) {
|
||||
server := server.New()
|
||||
defer server.Close()
|
||||
|
||||
var addrIDs []string
|
||||
|
||||
userID, addrID, err := server.AddUser(username, password, emails[0])
|
||||
require.NoError(t, err)
|
||||
|
||||
addrIDs = append(addrIDs, addrID)
|
||||
|
||||
for _, email := range emails[1:] {
|
||||
addrID, err := server.AddAddress(userID, email, password)
|
||||
require.NoError(t, err)
|
||||
|
||||
addrIDs = append(addrIDs, addrID)
|
||||
}
|
||||
|
||||
fn(ctx, server, userID, addrIDs)
|
||||
}
|
||||
|
||||
func withUser(t *testing.T, ctx context.Context, apiURL, username, password string, fn func(*user.User)) {
|
||||
c, apiAuth, err := liteapi.New(liteapi.WithHostURL(apiURL)).NewClientWithLogin(ctx, username, password)
|
||||
require.NoError(t, err)
|
||||
defer func() { require.NoError(t, c.Close()) }()
|
||||
|
||||
apiUser, apiAddrs, userKR, addrKRs, passphrase, err := c.Unlock(ctx, []byte(password))
|
||||
require.NoError(t, err)
|
||||
|
||||
vault, corrupt, err := vault.New(t.TempDir(), t.TempDir(), []byte("my secret key"))
|
||||
require.NoError(t, err)
|
||||
require.False(t, corrupt)
|
||||
|
||||
vaultUser, err := vault.AddUser(apiUser.ID, username, apiAuth.UID, apiAuth.RefreshToken, passphrase)
|
||||
require.NoError(t, err)
|
||||
|
||||
user, err := user.New(ctx, vaultUser, c, apiUser, apiAddrs, userKR, addrKRs)
|
||||
require.NoError(t, err)
|
||||
defer func() { require.NoError(t, user.Close(ctx)) }()
|
||||
|
||||
fn(user)
|
||||
}
|
||||
|
||||
func withIMAPClient(t *testing.T, addr string, fn func(*client.Client)) {
|
||||
c, err := client.Dial(addr)
|
||||
require.NoError(t, err)
|
||||
defer c.Close()
|
||||
|
||||
fn(c)
|
||||
}
|
||||
|
||||
func fetch(t *testing.T, c *client.Client, seqset string, items ...imap.FetchItem) []*imap.Message {
|
||||
msgCh := make(chan *imap.Message)
|
||||
|
||||
go func() {
|
||||
require.NoError(t, c.Fetch(must(imap.ParseSeqSet(seqset)), items, msgCh))
|
||||
}()
|
||||
|
||||
return iterator.Collect(iterator.Chan(msgCh))
|
||||
}
|
||||
|
||||
func must[T any](v T, err error) T {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
return v
|
||||
}
|
||||
Reference in New Issue
Block a user