mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-27 20:16:44 +00:00
chore: merge release/perth_narrows into devel
This commit is contained in:
@ -18,6 +18,7 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
@ -67,6 +68,10 @@ func (user *User) handleAPIEvent(ctx context.Context, event proton.Event) error
|
||||
}
|
||||
}
|
||||
|
||||
if event.UsedSpace != nil {
|
||||
user.handleUsedSpaceChange(*event.UsedSpace)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -194,6 +199,12 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
|
||||
|
||||
user.apiAddrs[event.Address.ID] = event.Address
|
||||
|
||||
// If the address is disabled.
|
||||
if event.Address.Status != proton.AddressStatusEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
// If the address is enabled, we need to hook it up to the update channels.
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := getAddrIdx(user.apiAddrs, 0)
|
||||
@ -220,6 +231,10 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
|
||||
|
||||
// Perform the sync in an RLock.
|
||||
return safe.RLockRet(func() error {
|
||||
if event.Address.Status != proton.AddressStatusEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
if err := syncLabels(ctx, user.apiLabels, user.updateCh[event.Address.ID]); err != nil {
|
||||
return fmt.Errorf("failed to sync labels to new address: %w", err)
|
||||
@ -237,18 +252,58 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
|
||||
"email": logging.Sensitive(event.Address.Email),
|
||||
}).Info("Handling address updated event")
|
||||
|
||||
if _, ok := user.apiAddrs[event.Address.ID]; !ok {
|
||||
oldAddr, ok := user.apiAddrs[event.Address.ID]
|
||||
if !ok {
|
||||
user.log.Debugf("Address %q does not exist", event.Address.ID)
|
||||
return nil
|
||||
}
|
||||
|
||||
user.apiAddrs[event.Address.ID] = event.Address
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressUpdated{
|
||||
UserID: user.apiUser.ID,
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
switch {
|
||||
// If the address was newly enabled:
|
||||
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
|
||||
switch user.vault.AddressMode() {
|
||||
case vault.CombinedMode:
|
||||
primAddr, err := getAddrIdx(user.apiAddrs, 0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get primary address: %w", err)
|
||||
}
|
||||
|
||||
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
|
||||
|
||||
case vault.SplitMode:
|
||||
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressEnabled{
|
||||
UserID: user.apiUser.ID,
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
// If the address was newly disabled:
|
||||
case oldAddr.Status == proton.AddressStatusEnabled && event.Address.Status != proton.AddressStatusEnabled:
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
user.updateCh[event.ID].CloseAndDiscardQueued()
|
||||
}
|
||||
|
||||
delete(user.updateCh, event.ID)
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressDisabled{
|
||||
UserID: user.apiUser.ID,
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
|
||||
// Otherwise it's just an update:
|
||||
default:
|
||||
user.eventCh.Enqueue(events.UserAddressUpdated{
|
||||
UserID: user.apiUser.ID,
|
||||
AddressID: event.Address.ID,
|
||||
Email: event.Address.Email,
|
||||
})
|
||||
}
|
||||
|
||||
return nil
|
||||
}, user.apiAddrsLock, user.updateChLock)
|
||||
@ -264,12 +319,20 @@ func (user *User) handleDeleteAddressEvent(_ context.Context, event proton.Addre
|
||||
return nil
|
||||
}
|
||||
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
user.updateCh[event.ID].CloseAndDiscardQueued()
|
||||
delete(user.updateCh, event.ID)
|
||||
delete(user.apiAddrs, event.ID)
|
||||
|
||||
// If the address was disabled to begin with, we don't need to do anything.
|
||||
if addr.Status != proton.AddressStatusEnabled {
|
||||
return nil
|
||||
}
|
||||
|
||||
delete(user.apiAddrs, event.ID)
|
||||
// Otherwise, in split mode, drop the update queue.
|
||||
if user.vault.AddressMode() == vault.SplitMode {
|
||||
user.updateCh[event.ID].CloseAndDiscardQueued()
|
||||
}
|
||||
|
||||
// And in either mode, remove the address from the update channel map.
|
||||
delete(user.updateCh, event.ID)
|
||||
|
||||
user.eventCh.Enqueue(events.UserAddressDeleted{
|
||||
UserID: user.apiUser.ID,
|
||||
@ -356,25 +419,51 @@ func (user *User) handleUpdateLabelEvent(ctx context.Context, event proton.Label
|
||||
"name": logging.Sensitive(event.Label.Name),
|
||||
}).Info("Handling label updated event")
|
||||
|
||||
// Only update the label if it exists; we don't want to create it as a client may have just deleted it.
|
||||
if _, ok := user.apiLabels[event.Label.ID]; ok {
|
||||
user.apiLabels[event.Label.ID] = event.Label
|
||||
}
|
||||
stack := []proton.Label{event.Label}
|
||||
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
update := imap.NewMailboxUpdated(
|
||||
imap.MailboxID(event.ID),
|
||||
getMailboxName(event.Label),
|
||||
)
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
for len(stack) > 0 {
|
||||
label := stack[0]
|
||||
stack = stack[1:]
|
||||
|
||||
user.eventCh.Enqueue(events.UserLabelUpdated{
|
||||
UserID: user.apiUser.ID,
|
||||
LabelID: event.Label.ID,
|
||||
Name: event.Label.Name,
|
||||
})
|
||||
// Only update the label if it exists; we don't want to create it as a client may have just deleted it.
|
||||
if _, ok := user.apiLabels[label.ID]; ok {
|
||||
user.apiLabels[label.ID] = event.Label
|
||||
}
|
||||
|
||||
// API doesn't notify us that the path has changed. We need to fetch it again.
|
||||
apiLabel, err := user.client.GetLabel(ctx, label.ID, label.Type)
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
user.log.WithError(apiErr).Warn("Failed to get label: label does not exist")
|
||||
continue
|
||||
} else if err != nil {
|
||||
return nil, fmt.Errorf("failed to get label %q: %w", label.ID, err)
|
||||
}
|
||||
|
||||
// Update the label in the map.
|
||||
user.apiLabels[apiLabel.ID] = apiLabel
|
||||
|
||||
// Notify the IMAP clients.
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
update := imap.NewMailboxUpdated(
|
||||
imap.MailboxID(apiLabel.ID),
|
||||
getMailboxName(apiLabel),
|
||||
)
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
|
||||
user.eventCh.Enqueue(events.UserLabelUpdated{
|
||||
UserID: user.apiUser.ID,
|
||||
LabelID: apiLabel.ID,
|
||||
Name: apiLabel.Name,
|
||||
})
|
||||
|
||||
children := xslices.Filter(maps.Values(user.apiLabels), func(other proton.Label) bool {
|
||||
return other.ParentID == label.ID
|
||||
})
|
||||
|
||||
stack = append(stack, children...)
|
||||
}
|
||||
|
||||
return updates, nil
|
||||
}, user.apiLabelsLock, user.updateChLock)
|
||||
@ -404,7 +493,7 @@ func (user *User) handleDeleteLabelEvent(ctx context.Context, event proton.Label
|
||||
}
|
||||
|
||||
// handleMessageEvents handles the given message events.
|
||||
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []proton.MessageEvent) error { //nolint:funlen
|
||||
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []proton.MessageEvent) error {
|
||||
for _, event := range messageEvents {
|
||||
ctx = logging.WithLogrusField(ctx, "messageID", event.ID)
|
||||
|
||||
@ -494,7 +583,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
|
||||
"subject": logging.Sensitive(message.Subject),
|
||||
}).Info("Handling message created event")
|
||||
|
||||
full, err := user.client.GetFullMessage(ctx, message.ID)
|
||||
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
@ -509,7 +598,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
|
||||
var update imap.Update
|
||||
|
||||
if err := withAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
res := buildRFC822(user.apiLabels, full, addrKR)
|
||||
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
|
||||
|
||||
if res.err != nil {
|
||||
user.log.WithError(err).Error("Failed to build RFC822 message")
|
||||
@ -553,7 +642,7 @@ func (user *User) handleUpdateMessageEvent(ctx context.Context, message proton.M
|
||||
Seen: message.Seen(),
|
||||
Flagged: message.Starred(),
|
||||
Draft: message.IsDraft(),
|
||||
Answered: message.IsReplied == true || message.IsRepliedAll == true, //nolint: gosimple
|
||||
Answered: message.IsRepliedAll == true || message.IsReplied == true, //nolint: gosimple
|
||||
},
|
||||
)
|
||||
|
||||
@ -586,7 +675,7 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
|
||||
"subject": logging.Sensitive(event.Message.Subject),
|
||||
}).Info("Handling draft updated event")
|
||||
|
||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID)
|
||||
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
// If the message is not found, it means that it has been deleted before we could fetch it.
|
||||
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
|
||||
@ -600,7 +689,7 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
|
||||
var update imap.Update
|
||||
|
||||
if err := withAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
|
||||
res := buildRFC822(user.apiLabels, full, addrKR)
|
||||
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
|
||||
|
||||
if res.err != nil {
|
||||
logrus.WithError(err).Error("Failed to build RFC822 message")
|
||||
@ -637,6 +726,20 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
func (user *User) handleUsedSpaceChange(usedSpace int) {
|
||||
safe.Lock(func() {
|
||||
if user.apiUser.UsedSpace == usedSpace {
|
||||
return
|
||||
}
|
||||
|
||||
user.apiUser.UsedSpace = usedSpace
|
||||
user.eventCh.Enqueue(events.UsedSpaceChanged{
|
||||
UserID: user.apiUser.ID,
|
||||
UsedSpace: usedSpace,
|
||||
})
|
||||
}, user.apiUserLock)
|
||||
}
|
||||
|
||||
func getMailboxName(label proton.Label) []string {
|
||||
var name []string
|
||||
|
||||
|
||||
@ -264,8 +264,6 @@ func (conn *imapConnector) DeleteMailbox(ctx context.Context, labelID imap.Mailb
|
||||
}
|
||||
|
||||
// CreateMessage creates a new message on the remote.
|
||||
//
|
||||
// nolint:funlen
|
||||
func (conn *imapConnector) CreateMessage(
|
||||
ctx context.Context,
|
||||
mailboxID imap.MailboxID,
|
||||
@ -292,7 +290,7 @@ func (conn *imapConnector) CreateMessage(
|
||||
conn.log.WithField("messageID", messageID).Warn("Message already sent")
|
||||
|
||||
// Query the server-side message.
|
||||
full, err := conn.client.GetFullMessage(ctx, messageID)
|
||||
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
@ -356,7 +354,7 @@ func (conn *imapConnector) CreateMessage(
|
||||
}
|
||||
|
||||
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
|
||||
msg, err := conn.client.GetFullMessage(ctx, string(id))
|
||||
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -382,7 +380,7 @@ func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.Messag
|
||||
func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs []imap.MessageID, mailboxID imap.MailboxID) error {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if mailboxID == proton.AllMailLabel {
|
||||
if isAllMailOrScheduled(mailboxID) {
|
||||
return connector.ErrOperationNotAllowed
|
||||
}
|
||||
|
||||
@ -393,7 +391,7 @@ func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs
|
||||
func (conn *imapConnector) RemoveMessagesFromMailbox(ctx context.Context, messageIDs []imap.MessageID, mailboxID imap.MailboxID) error {
|
||||
defer conn.goPollAPIEvents(false)
|
||||
|
||||
if mailboxID == proton.AllMailLabel {
|
||||
if isAllMailOrScheduled(mailboxID) {
|
||||
return connector.ErrOperationNotAllowed
|
||||
}
|
||||
|
||||
@ -442,8 +440,8 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
|
||||
|
||||
if (labelFromID == proton.InboxLabel && labelToID == proton.SentLabel) ||
|
||||
(labelFromID == proton.SentLabel && labelToID == proton.InboxLabel) ||
|
||||
labelFromID == proton.AllMailLabel ||
|
||||
labelToID == proton.AllMailLabel {
|
||||
isAllMailOrScheduled(labelFromID) ||
|
||||
isAllMailOrScheduled(labelToID) {
|
||||
return false, connector.ErrOperationNotAllowed
|
||||
}
|
||||
|
||||
@ -507,19 +505,20 @@ func (conn *imapConnector) GetUpdates() <-chan imap.Update {
|
||||
}, conn.updateChLock)
|
||||
}
|
||||
|
||||
// GetUIDValidity returns the default UID validity for this user.
|
||||
func (conn *imapConnector) GetUIDValidity() imap.UID {
|
||||
return conn.vault.GetUIDValidity(conn.addrID)
|
||||
}
|
||||
// GetMailboxVisibility returns the visibility of a mailbox over IMAP.
|
||||
func (conn *imapConnector) GetMailboxVisibility(_ context.Context, mailboxID imap.MailboxID) imap.MailboxVisibility {
|
||||
switch mailboxID {
|
||||
case proton.AllMailLabel:
|
||||
if atomic.LoadUint32(&conn.showAllMail) != 0 {
|
||||
return imap.Visible
|
||||
}
|
||||
return imap.Hidden
|
||||
|
||||
// SetUIDValidity sets the default UID validity for this user.
|
||||
func (conn *imapConnector) SetUIDValidity(validity imap.UID) error {
|
||||
return conn.vault.SetUIDValidity(conn.addrID, validity)
|
||||
}
|
||||
|
||||
// IsMailboxVisible returns whether this mailbox should be visible over IMAP.
|
||||
func (conn *imapConnector) IsMailboxVisible(_ context.Context, mailboxID imap.MailboxID) bool {
|
||||
return atomic.LoadUint32(&conn.showAllMail) != 0 || mailboxID != proton.AllMailLabel
|
||||
case proton.AllScheduledLabel:
|
||||
return imap.HiddenIfEmpty
|
||||
default:
|
||||
return imap.Visible
|
||||
}
|
||||
}
|
||||
|
||||
// Close the connector will no longer be used and all resources should be closed/released.
|
||||
@ -550,7 +549,7 @@ func (conn *imapConnector) importMessage(
|
||||
|
||||
messageID = msg.ID
|
||||
} else {
|
||||
res, err := stream.Collect(ctx, conn.client.ImportMessages(ctx, addrKR, 1, 1, []proton.ImportReq{{
|
||||
str, err := conn.client.ImportMessages(ctx, addrKR, 1, 1, []proton.ImportReq{{
|
||||
Metadata: proton.ImportMetadata{
|
||||
AddressID: conn.addrID,
|
||||
LabelIDs: labelIDs,
|
||||
@ -558,7 +557,12 @@ func (conn *imapConnector) importMessage(
|
||||
Flags: flags,
|
||||
},
|
||||
Message: literal,
|
||||
}}...))
|
||||
}}...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to prepare message for import: %w", err)
|
||||
}
|
||||
|
||||
res, err := stream.Collect(ctx, str)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to import message: %w", err)
|
||||
}
|
||||
@ -568,7 +572,7 @@ func (conn *imapConnector) importMessage(
|
||||
|
||||
var err error
|
||||
|
||||
if full, err = conn.client.GetFullMessage(ctx, messageID); err != nil {
|
||||
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()); err != nil {
|
||||
return fmt.Errorf("failed to fetch message: %w", err)
|
||||
}
|
||||
|
||||
@ -615,7 +619,7 @@ func toIMAPMessage(message proton.MessageMetadata) imap.Message {
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *imapConnector) createDraft(ctx context.Context, literal []byte, addrKR *crypto.KeyRing, sender proton.Address) (proton.Message, error) { //nolint:funlen
|
||||
func (conn *imapConnector) createDraft(ctx context.Context, literal []byte, addrKR *crypto.KeyRing, sender proton.Address) (proton.Message, error) {
|
||||
// Create a new message parser from the reader.
|
||||
parser, err := parser.New(bytes.NewReader(literal))
|
||||
if err != nil {
|
||||
@ -687,3 +691,7 @@ func toIMAPMailbox(label proton.Label, flags, permFlags, attrs imap.FlagSet) ima
|
||||
Attributes: attrs,
|
||||
}
|
||||
}
|
||||
|
||||
func isAllMailOrScheduled(mailboxID imap.MailboxID) bool {
|
||||
return (mailboxID == proton.AllMailLabel) || (mailboxID == proton.AllScheduledLabel)
|
||||
}
|
||||
|
||||
@ -218,8 +218,6 @@ func (h *sendRecorder) getWaitCh(hash string) (<-chan struct{}, bool) {
|
||||
// - the Content-Type header of each (leaf) part,
|
||||
// - the Content-Disposition header of each (leaf) part,
|
||||
// - the (decoded) body of each part.
|
||||
//
|
||||
// nolint:funlen
|
||||
func getMessageHash(b []byte) (string, error) {
|
||||
section := rfc822.Parse(b)
|
||||
|
||||
|
||||
@ -47,8 +47,6 @@ import (
|
||||
)
|
||||
|
||||
// sendMail sends an email from the given address to the given recipients.
|
||||
//
|
||||
// nolint:funlen
|
||||
func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error {
|
||||
return safe.RLockRet(func() error {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@ -165,7 +163,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
|
||||
}
|
||||
|
||||
// sendWithKey sends the message with the given address key.
|
||||
func sendWithKey( //nolint:funlen
|
||||
func sendWithKey(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
sentry reporter.Reporter,
|
||||
@ -247,7 +245,7 @@ func sendWithKey( //nolint:funlen
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func getParentID( //nolint:funlen
|
||||
func getParentID(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
authAddrID string,
|
||||
@ -375,7 +373,6 @@ func createDraft(
|
||||
})
|
||||
}
|
||||
|
||||
// nolint:funlen
|
||||
func createAttachments(
|
||||
ctx context.Context,
|
||||
client *proton.Client,
|
||||
@ -468,12 +465,12 @@ func getRecipients(
|
||||
prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) {
|
||||
pubKeys, recType, err := client.GetPublicKeys(ctx, recipient)
|
||||
if err != nil {
|
||||
return proton.SendPreferences{}, fmt.Errorf("failed to get public keys: %w (%v)", err, recipient)
|
||||
return proton.SendPreferences{}, fmt.Errorf("failed to get public key for %v: %w", recipient, err)
|
||||
}
|
||||
|
||||
contactSettings, err := getContactSettings(ctx, client, userKR, recipient)
|
||||
if err != nil {
|
||||
return proton.SendPreferences{}, fmt.Errorf("failed to get contact settings: %w", err)
|
||||
return proton.SendPreferences{}, fmt.Errorf("failed to get contact settings for %v: %w", recipient, err)
|
||||
}
|
||||
|
||||
return buildSendPrefs(contactSettings, settings, pubKeys, draft.MIMEType, recType == proton.RecipientTypeInternal)
|
||||
|
||||
@ -18,6 +18,7 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
@ -25,6 +26,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/logging"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gluon/reporter"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
@ -34,18 +36,38 @@ import (
|
||||
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pbnjay/memory"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const (
|
||||
maxUpdateSize = 1 << 27 // 128 MiB
|
||||
maxBatchSize = 1 << 8 // 256
|
||||
)
|
||||
// syncSystemLabels ensures that system labels are all known to gluon.
|
||||
func (user *User) syncSystemLabels(ctx context.Context) error {
|
||||
return safe.RLockRet(func() error {
|
||||
var updates []imap.Update
|
||||
|
||||
// doSync begins syncing the users data.
|
||||
for _, label := range xslices.Filter(maps.Values(user.apiLabels), func(label proton.Label) bool { return label.Type == proton.LabelTypeSystem }) {
|
||||
if !wantLabel(label) {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
|
||||
update := newSystemMailboxCreatedUpdate(imap.MailboxID(label.ID), label.Name)
|
||||
updateCh.Enqueue(update)
|
||||
updates = append(updates, update)
|
||||
}
|
||||
}
|
||||
|
||||
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
|
||||
return fmt.Errorf("could not sync system labels: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
|
||||
}
|
||||
|
||||
// doSync begins syncing the user's data.
|
||||
// It first ensures the latest event ID is known; if not, it fetches it.
|
||||
// It sends a SyncStarted event and then either SyncFinished or SyncFailed
|
||||
// depending on whether the sync was successful.
|
||||
@ -89,7 +111,6 @@ func (user *User) doSync(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:funlen
|
||||
func (user *User) sync(ctx context.Context) error {
|
||||
return safe.RLockRet(func() error {
|
||||
return withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
|
||||
@ -143,7 +164,7 @@ func (user *User) sync(ctx context.Context) error {
|
||||
addrKRs,
|
||||
user.updateCh,
|
||||
user.eventCh,
|
||||
user.syncWorkers,
|
||||
user.maxSyncMemory,
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
@ -166,7 +187,7 @@ func (user *User) sync(ctx context.Context) error {
|
||||
func syncLabels(ctx context.Context, apiLabels map[string]proton.Label, updateCh ...*queue.QueuedChannel[imap.Update]) error {
|
||||
var updates []imap.Update
|
||||
|
||||
// Create placeholder Folders/Labels mailboxes with a random ID and with the \Noselect attribute.
|
||||
// Create placeholder Folders/Labels mailboxes with the \Noselect attribute.
|
||||
for _, prefix := range []string{folderPrefix, labelPrefix} {
|
||||
for _, updateCh := range updateCh {
|
||||
update := newPlaceHolderMailboxCreatedUpdate(prefix)
|
||||
@ -212,7 +233,15 @@ func syncLabels(ctx context.Context, apiLabels map[string]proton.Label, updateCh
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:funlen
|
||||
const Kilobyte = uint64(1024)
|
||||
const Megabyte = 1024 * Kilobyte
|
||||
const Gigabyte = 1024 * Megabyte
|
||||
|
||||
func toMB(v uint64) float64 {
|
||||
return float64(v) / float64(Megabyte)
|
||||
}
|
||||
|
||||
// nolint:gocyclo
|
||||
func syncMessages(
|
||||
ctx context.Context,
|
||||
userID string,
|
||||
@ -224,7 +253,7 @@ func syncMessages(
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
updateCh map[string]*queue.QueuedChannel[imap.Update],
|
||||
eventCh *queue.QueuedChannel[events.Event],
|
||||
syncWorkers int,
|
||||
maxSyncMemory uint64,
|
||||
) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
@ -235,78 +264,296 @@ func syncMessages(
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"messages": len(messageIDs),
|
||||
"workers": syncWorkers,
|
||||
"numCPU": runtime.NumCPU(),
|
||||
}).Info("Starting message sync")
|
||||
|
||||
// Create the flushers, one per update channel.
|
||||
flushers := make(map[string]*flusher, len(updateCh))
|
||||
|
||||
for addrID, updateCh := range updateCh {
|
||||
flushers[addrID] = newFlusher(updateCh, maxUpdateSize)
|
||||
}
|
||||
|
||||
// Create a reporter to report sync progress updates.
|
||||
syncReporter := newSyncReporter(userID, eventCh, len(messageIDs), time.Second)
|
||||
defer syncReporter.done()
|
||||
|
||||
type flushUpdate struct {
|
||||
messageID string
|
||||
pushedUpdates []imap.Update
|
||||
batchLen int
|
||||
// Expected mem usage for this whole process should be the sum of MaxMessageBuildingMem and MaxDownloadRequestMem
|
||||
// times x due to pipeline and all additional memory used by network requests and compression+io.
|
||||
|
||||
// There's no point in using more than 128MB of download data per stage, after that we reach a point of diminishing
|
||||
// returns as we can't keep the pipeline fed fast enough.
|
||||
const MaxDownloadRequestMem = 128 * Megabyte
|
||||
|
||||
// Any lower than this and we may fail to download messages.
|
||||
const MinDownloadRequestMem = 40 * Megabyte
|
||||
|
||||
// This value can be increased to your hearts content. The more system memory the user has, the more messages
|
||||
// we can build in parallel.
|
||||
const MaxMessageBuildingMem = 128 * Megabyte
|
||||
const MinMessageBuildingMem = 64 * Megabyte
|
||||
|
||||
// Maximum recommend value for parallel downloads by the API team.
|
||||
const maxParallelDownloads = 20
|
||||
|
||||
totalMemory := memory.TotalMemory()
|
||||
|
||||
if maxSyncMemory >= totalMemory/2 {
|
||||
logrus.Warnf("Requested max sync memory of %v MB is greater than half of system memory (%v MB), forcing to half of system memory",
|
||||
maxSyncMemory, toMB(totalMemory/2))
|
||||
maxSyncMemory = totalMemory / 2
|
||||
}
|
||||
|
||||
if maxSyncMemory < 800*Megabyte {
|
||||
logrus.Warnf("Requested max sync memory of %v MB, but minimum recommended is 800 MB, forcing max syncMemory to 800MB", toMB(maxSyncMemory))
|
||||
maxSyncMemory = 800 * Megabyte
|
||||
}
|
||||
|
||||
logrus.Debugf("Total System Memory: %v", toMB(totalMemory))
|
||||
|
||||
syncMaxDownloadRequestMem := MaxDownloadRequestMem
|
||||
syncMaxMessageBuildingMem := MaxMessageBuildingMem
|
||||
|
||||
// If less than 2GB available try and limit max memory to 512 MB
|
||||
switch {
|
||||
case maxSyncMemory < 2*Gigabyte:
|
||||
if maxSyncMemory < 800*Megabyte {
|
||||
logrus.Warnf("System has less than 800MB of memory, you may experience issues sycing large mailboxes")
|
||||
}
|
||||
syncMaxDownloadRequestMem = MinDownloadRequestMem
|
||||
syncMaxMessageBuildingMem = MinMessageBuildingMem
|
||||
case maxSyncMemory == 2*Gigabyte:
|
||||
// Increasing the max download capacity has very little effect on sync speed. We could increase the download
|
||||
// memory but the user would see less sync notifications. A smaller value here leads to more frequent
|
||||
// updates. Additionally, most of ot sync time is spent in the message building.
|
||||
syncMaxDownloadRequestMem = MaxDownloadRequestMem
|
||||
// Currently limited so that if a user has multiple accounts active it also doesn't cause excessive memory usage.
|
||||
syncMaxMessageBuildingMem = MaxMessageBuildingMem
|
||||
default:
|
||||
// Divide by 8 as download stage and build stage will use aprox. 4x the specified memory.
|
||||
remainingMemory := (maxSyncMemory - 2*Gigabyte) / 8
|
||||
syncMaxDownloadRequestMem = MaxDownloadRequestMem + remainingMemory
|
||||
syncMaxMessageBuildingMem = MaxMessageBuildingMem + remainingMemory
|
||||
}
|
||||
|
||||
logrus.Debugf("Max memory usage for sync Download=%vMB Building=%vMB Predicted Max Total=%vMB",
|
||||
toMB(syncMaxDownloadRequestMem),
|
||||
toMB(syncMaxMessageBuildingMem),
|
||||
toMB((syncMaxMessageBuildingMem*4)+(syncMaxDownloadRequestMem*4)),
|
||||
)
|
||||
|
||||
type flushUpdate struct {
|
||||
messageID string
|
||||
err error
|
||||
batchLen int
|
||||
}
|
||||
|
||||
type downloadRequest struct {
|
||||
ids []string
|
||||
expectedSize uint64
|
||||
err error
|
||||
}
|
||||
|
||||
type downloadedMessageBatch struct {
|
||||
batch []proton.FullMessage
|
||||
}
|
||||
|
||||
type builtMessageBatch struct {
|
||||
batch []*buildRes
|
||||
}
|
||||
|
||||
downloadCh := make(chan downloadRequest)
|
||||
|
||||
buildCh := make(chan downloadedMessageBatch)
|
||||
|
||||
// The higher this value, the longer we can continue our download iteration before being blocked on channel writes
|
||||
// to the update flushing goroutine.
|
||||
flushCh := make(chan []*buildRes, 2)
|
||||
flushCh := make(chan builtMessageBatch)
|
||||
|
||||
// Allow up to 4 batched wait requests.
|
||||
flushUpdateCh := make(chan flushUpdate, 4)
|
||||
flushUpdateCh := make(chan flushUpdate)
|
||||
|
||||
errorCh := make(chan error, syncWorkers)
|
||||
errorCh := make(chan error, maxParallelDownloads*4)
|
||||
|
||||
// Go routine in charge of downloading message metadata
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
defer close(downloadCh)
|
||||
const MetadataDataPageSize = 150
|
||||
|
||||
var downloadReq downloadRequest
|
||||
downloadReq.ids = make([]string, 0, MetadataDataPageSize)
|
||||
|
||||
metadataChunks := xslices.Chunk(messageIDs, MetadataDataPageSize)
|
||||
for i, metadataChunk := range metadataChunks {
|
||||
logrus.Debugf("Metadata Request (%v of %v), previous: %v", i, len(metadataChunks), len(downloadReq.ids))
|
||||
metadata, err := client.GetMessageMetadataPage(ctx, 0, len(metadataChunk), proton.MessageFilter{ID: metadataChunk})
|
||||
if err != nil {
|
||||
downloadReq.err = err
|
||||
select {
|
||||
case downloadCh <- downloadReq:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
|
||||
// Build look up table so that messages are processed in the same order.
|
||||
metadataMap := make(map[string]int, len(metadata))
|
||||
for i, v := range metadata {
|
||||
metadataMap[v.ID] = i
|
||||
}
|
||||
|
||||
for i, id := range metadataChunk {
|
||||
m := &metadata[metadataMap[id]]
|
||||
nextSize := downloadReq.expectedSize + uint64(m.Size)
|
||||
if nextSize >= syncMaxDownloadRequestMem || len(downloadReq.ids) >= 256 {
|
||||
logrus.Debugf("Download Request Sent at %v of %v", i, len(metadata))
|
||||
select {
|
||||
case downloadCh <- downloadReq:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
downloadReq.expectedSize = 0
|
||||
downloadReq.ids = make([]string, 0, MetadataDataPageSize)
|
||||
nextSize = uint64(m.Size)
|
||||
}
|
||||
downloadReq.ids = append(downloadReq.ids, id)
|
||||
downloadReq.expectedSize = nextSize
|
||||
}
|
||||
}
|
||||
|
||||
if len(downloadReq.ids) != 0 {
|
||||
logrus.Debugf("Sending remaining download request")
|
||||
select {
|
||||
case downloadCh <- downloadReq:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}, logging.Labels{"sync-stage": "meta-data"})
|
||||
|
||||
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
||||
go func() {
|
||||
defer close(flushCh)
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
defer close(buildCh)
|
||||
defer close(errorCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync downloader exit")
|
||||
}()
|
||||
|
||||
attachmentDownloader := newAttachmentDownloader(ctx, client, maxParallelDownloads)
|
||||
defer attachmentDownloader.close()
|
||||
|
||||
for request := range downloadCh {
|
||||
logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize))
|
||||
if request.err != nil {
|
||||
errorCh <- request.err
|
||||
return
|
||||
}
|
||||
|
||||
for _, batch := range xslices.Chunk(messageIDs, maxBatchSize) {
|
||||
if ctx.Err() != nil {
|
||||
errorCh <- ctx.Err()
|
||||
return
|
||||
}
|
||||
|
||||
result, err := parallel.MapContext(ctx, syncWorkers, batch, func(ctx context.Context, id string) (*buildRes, error) {
|
||||
msg, err := client.GetFullMessage(ctx, id)
|
||||
result, err := parallel.MapContext(ctx, maxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
|
||||
var result proton.FullMessage
|
||||
|
||||
msg, err := client.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return proton.FullMessage{}, err
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
|
||||
if err != nil {
|
||||
return proton.FullMessage{}, err
|
||||
}
|
||||
|
||||
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID]), nil
|
||||
result.Message = msg
|
||||
result.AttData = attachments
|
||||
|
||||
return result, nil
|
||||
})
|
||||
if err != nil {
|
||||
errorCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case buildCh <- downloadedMessageBatch{
|
||||
batch: result,
|
||||
}:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}, logging.Labels{"sync-stage": "download"})
|
||||
|
||||
// Goroutine which builds messages after they have been downloaded
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
defer close(flushCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync builder exit")
|
||||
}()
|
||||
|
||||
maxMessagesInParallel := runtime.NumCPU()
|
||||
|
||||
for buildBatch := range buildCh {
|
||||
if ctx.Err() != nil {
|
||||
errorCh <- ctx.Err()
|
||||
return
|
||||
}
|
||||
|
||||
flushCh <- result
|
||||
chunks := chunkSyncBuilderBatch(buildBatch.batch, syncMaxMessageBuildingMem)
|
||||
|
||||
for index, chunk := range chunks {
|
||||
logrus.Debugf("Build request: %v of %v count=%v", index, len(chunks), len(chunk))
|
||||
|
||||
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
|
||||
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case flushCh <- builtMessageBatch{result}:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}, logging.Labels{"sync-stage": "builder"})
|
||||
|
||||
// Goroutine which converts the messages into updates and builds a waitable structure for progress tracking.
|
||||
go func() {
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
defer close(flushUpdateCh)
|
||||
for batch := range flushCh {
|
||||
for _, res := range batch {
|
||||
defer func() {
|
||||
logrus.Debugf("sync flush exit")
|
||||
}()
|
||||
|
||||
type updateTargetInfo struct {
|
||||
queueIndex int
|
||||
ch *queue.QueuedChannel[imap.Update]
|
||||
}
|
||||
|
||||
pendingUpdates := make([][]*imap.MessageCreated, len(updateCh))
|
||||
addressToIndex := make(map[string]updateTargetInfo)
|
||||
|
||||
{
|
||||
i := 0
|
||||
for addrID, updateCh := range updateCh {
|
||||
addressToIndex[addrID] = updateTargetInfo{
|
||||
ch: updateCh,
|
||||
queueIndex: i,
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
for downloadBatch := range flushCh {
|
||||
logrus.Debugf("Flush batch: %v", len(downloadBatch.batch))
|
||||
for _, res := range downloadBatch.batch {
|
||||
if res.err != nil {
|
||||
if err := vault.AddFailedMessageID(res.messageID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to add failed message ID")
|
||||
@ -327,31 +574,38 @@ func syncMessages(
|
||||
}
|
||||
}
|
||||
|
||||
flushers[res.addressID].push(res.update)
|
||||
targetInfo := addressToIndex[res.addressID]
|
||||
pendingUpdates[targetInfo.queueIndex] = append(pendingUpdates[targetInfo.queueIndex], res.update)
|
||||
}
|
||||
|
||||
var pushedUpdates []imap.Update
|
||||
for _, flusher := range flushers {
|
||||
flusher.flush()
|
||||
pushedUpdates = append(pushedUpdates, flusher.collectPushedUpdates()...)
|
||||
for _, info := range addressToIndex {
|
||||
up := imap.NewMessagesCreated(true, pendingUpdates[info.queueIndex]...)
|
||||
info.ch.Enqueue(up)
|
||||
|
||||
err, ok := up.WaitContext(ctx)
|
||||
if ok && err != nil {
|
||||
flushUpdateCh <- flushUpdate{
|
||||
err: fmt.Errorf("failed to apply sync update to gluon %v: %w", up.String(), err),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
pendingUpdates[info.queueIndex] = pendingUpdates[info.queueIndex][:0]
|
||||
}
|
||||
|
||||
flushUpdateCh <- flushUpdate{
|
||||
messageID: batch[0].messageID,
|
||||
pushedUpdates: pushedUpdates,
|
||||
batchLen: len(batch),
|
||||
select {
|
||||
case flushUpdateCh <- flushUpdate{
|
||||
messageID: downloadBatch.batch[0].messageID,
|
||||
err: nil,
|
||||
batchLen: len(downloadBatch.batch),
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}, logging.Labels{"sync-stage": "flush"})
|
||||
|
||||
for flushUpdate := range flushUpdateCh {
|
||||
for _, up := range flushUpdate.pushedUpdates {
|
||||
err, ok := up.WaitContext(ctx)
|
||||
if ok && err != nil {
|
||||
return fmt.Errorf("failed to apply sync update to gluon %v: %w", up.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := vault.SetLastMessageID(flushUpdate.messageID); err != nil {
|
||||
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
||||
}
|
||||
@ -394,6 +648,9 @@ func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *im
|
||||
|
||||
case proton.StarredLabel:
|
||||
attrs = attrs.Add(imap.AttrFlagged)
|
||||
|
||||
case proton.AllScheduledLabel:
|
||||
labelName = "Scheduled" // API actual name is "All Scheduled"
|
||||
}
|
||||
|
||||
return imap.NewMailboxCreated(imap.Mailbox{
|
||||
@ -407,7 +664,7 @@ func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *im
|
||||
|
||||
func newPlaceHolderMailboxCreatedUpdate(labelName string) *imap.MailboxCreated {
|
||||
return imap.NewMailboxCreated(imap.Mailbox{
|
||||
ID: imap.MailboxID(uuid.NewString()),
|
||||
ID: imap.MailboxID(labelName),
|
||||
Name: []string{labelName},
|
||||
Flags: defaultFlags,
|
||||
PermanentFlags: defaultPermanentFlags,
|
||||
@ -456,6 +713,9 @@ func wantLabel(label proton.Label) bool {
|
||||
case proton.StarredLabel:
|
||||
return true
|
||||
|
||||
case proton.AllScheduledLabel:
|
||||
return true
|
||||
|
||||
default:
|
||||
return false
|
||||
}
|
||||
@ -471,3 +731,126 @@ func wantLabels(apiLabels map[string]proton.Label, labelIDs []string) []string {
|
||||
return wantLabel(apiLabel)
|
||||
})
|
||||
}
|
||||
|
||||
type attachmentResult struct {
|
||||
attachment []byte
|
||||
err error
|
||||
}
|
||||
|
||||
type attachmentJob struct {
|
||||
id string
|
||||
size int64
|
||||
result chan attachmentResult
|
||||
}
|
||||
|
||||
type attachmentDownloader struct {
|
||||
workerCh chan attachmentJob
|
||||
cancel context.CancelFunc
|
||||
}
|
||||
|
||||
func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan attachmentJob) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case job, ok := <-work:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
var b bytes.Buffer
|
||||
b.Grow(int(job.size))
|
||||
err := client.GetAttachmentInto(ctx, job.id, &b)
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
close(job.result)
|
||||
return
|
||||
case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
|
||||
close(job.result)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
|
||||
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
for i := 0; i < workerCount; i++ {
|
||||
workerCh = make(chan attachmentJob)
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
|
||||
"sync": fmt.Sprintf("att-downloader %v", i),
|
||||
})
|
||||
}
|
||||
|
||||
return &attachmentDownloader{
|
||||
workerCh: workerCh,
|
||||
cancel: cancel,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) {
|
||||
resultChs := make([]chan attachmentResult, len(attachments))
|
||||
for i, id := range attachments {
|
||||
resultChs[i] = make(chan attachmentResult, 1)
|
||||
select {
|
||||
case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}:
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
result := make([][]byte, len(attachments))
|
||||
var err error
|
||||
for i := 0; i < len(attachments); i++ {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case r := <-resultChs[i]:
|
||||
if r.err != nil {
|
||||
err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err)
|
||||
}
|
||||
result[i] = r.attachment
|
||||
}
|
||||
}
|
||||
|
||||
return result, err
|
||||
}
|
||||
|
||||
func (a *attachmentDownloader) close() {
|
||||
a.cancel()
|
||||
}
|
||||
|
||||
func chunkSyncBuilderBatch(batch []proton.FullMessage, maxMemory uint64) [][]proton.FullMessage {
|
||||
var expectedMemUsage uint64
|
||||
var chunks [][]proton.FullMessage
|
||||
var lastIndex int
|
||||
var index int
|
||||
|
||||
for _, v := range batch {
|
||||
var dataSize uint64
|
||||
for _, a := range v.Attachments {
|
||||
dataSize += uint64(a.Size)
|
||||
}
|
||||
|
||||
// 2x increase for attachment due to extra memory needed for decrypting and writing
|
||||
// in memory buffer.
|
||||
dataSize *= 2
|
||||
dataSize += uint64(len(v.Body))
|
||||
|
||||
nextMemSize := expectedMemUsage + dataSize
|
||||
if nextMemSize >= maxMemory {
|
||||
chunks = append(chunks, batch[lastIndex:index])
|
||||
lastIndex = index
|
||||
expectedMemUsage = dataSize
|
||||
} else {
|
||||
expectedMemUsage = nextMemSize
|
||||
}
|
||||
|
||||
index++
|
||||
}
|
||||
|
||||
if lastIndex < len(batch) {
|
||||
chunks = append(chunks, batch[lastIndex:])
|
||||
}
|
||||
|
||||
return chunks
|
||||
}
|
||||
|
||||
@ -48,16 +48,18 @@ func defaultJobOpts() message.JobOptions {
|
||||
}
|
||||
}
|
||||
|
||||
func buildRFC822(apiLabels map[string]proton.Label, full proton.FullMessage, addrKR *crypto.KeyRing) *buildRes {
|
||||
func buildRFC822(apiLabels map[string]proton.Label, full proton.FullMessage, addrKR *crypto.KeyRing, buffer *bytes.Buffer) *buildRes {
|
||||
var (
|
||||
update *imap.MessageCreated
|
||||
err error
|
||||
)
|
||||
|
||||
if literal, buildErr := message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); buildErr != nil {
|
||||
buffer.Grow(full.Size)
|
||||
|
||||
if buildErr := message.BuildRFC822Into(addrKR, full.Message, full.AttData, defaultJobOpts(), buffer); buildErr != nil {
|
||||
update = newMessageCreatedFailedUpdate(apiLabels, full.MessageMetadata, buildErr)
|
||||
err = buildErr
|
||||
} else if created, parseErr := newMessageCreatedUpdate(apiLabels, full.MessageMetadata, literal); parseErr != nil {
|
||||
} else if created, parseErr := newMessageCreatedUpdate(apiLabels, full.MessageMetadata, buffer.Bytes()); parseErr != nil {
|
||||
update = newMessageCreatedFailedUpdate(apiLabels, full.MessageMetadata, parseErr)
|
||||
err = parseErr
|
||||
} else {
|
||||
|
||||
@ -24,6 +24,8 @@ import (
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/rfc822"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
@ -47,3 +49,32 @@ func TestNewFailedMessageLiteral(t *testing.T) {
|
||||
require.Equal(t, `("text" "plain" () NIL NIL "base64" 114 2)`, parsed.Body)
|
||||
require.Equal(t, `("text" "plain" () NIL NIL "base64" 114 2 NIL NIL NIL NIL)`, parsed.Structure)
|
||||
}
|
||||
|
||||
func TestSyncChunkSyncBuilderBatch(t *testing.T) {
|
||||
// GODT-2424 - Some messages were not fully built due to a bug in the chunking if the total memory used by the
|
||||
// message would be higher than the maximum we allowed.
|
||||
const totalMessageCount = 100
|
||||
|
||||
msg := proton.FullMessage{
|
||||
Message: proton.Message{
|
||||
Attachments: []proton.Attachment{
|
||||
{
|
||||
Size: int64(8 * Megabyte),
|
||||
},
|
||||
},
|
||||
},
|
||||
AttData: nil,
|
||||
}
|
||||
|
||||
messages := xslices.Repeat(msg, totalMessageCount)
|
||||
|
||||
chunks := chunkSyncBuilderBatch(messages, 16*Megabyte)
|
||||
|
||||
var totalMessagesInChunks int
|
||||
|
||||
for _, v := range chunks {
|
||||
totalMessagesInChunks += len(v)
|
||||
}
|
||||
|
||||
require.Equal(t, totalMessagesInChunks, totalMessageCount)
|
||||
}
|
||||
|
||||
@ -1,63 +0,0 @@
|
||||
// Copyright (c) 2023 Proton AG
|
||||
//
|
||||
// This file is part of Proton Mail Bridge.
|
||||
//
|
||||
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
||||
// it under the terms of the GNU General Public License as published by
|
||||
// the Free Software Foundation, either version 3 of the License, or
|
||||
// (at your option) any later version.
|
||||
//
|
||||
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
||||
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
// GNU General Public License for more details.
|
||||
//
|
||||
// You should have received a copy of the GNU General Public License
|
||||
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
||||
|
||||
package user
|
||||
|
||||
import (
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
)
|
||||
|
||||
type flusher struct {
|
||||
updateCh *queue.QueuedChannel[imap.Update]
|
||||
updates []*imap.MessageCreated
|
||||
pushedUpdates []imap.Update
|
||||
|
||||
maxUpdateSize int
|
||||
curChunkSize int
|
||||
}
|
||||
|
||||
func newFlusher(updateCh *queue.QueuedChannel[imap.Update], maxUpdateSize int) *flusher {
|
||||
return &flusher{
|
||||
updateCh: updateCh,
|
||||
maxUpdateSize: maxUpdateSize,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) push(update *imap.MessageCreated) {
|
||||
f.updates = append(f.updates, update)
|
||||
|
||||
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxUpdateSize {
|
||||
f.flush()
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) flush() {
|
||||
if len(f.updates) > 0 {
|
||||
update := imap.NewMessagesCreated(true, f.updates...)
|
||||
f.updateCh.Enqueue(update)
|
||||
f.updates = nil
|
||||
f.curChunkSize = 0
|
||||
f.pushedUpdates = append(f.pushedUpdates, update)
|
||||
}
|
||||
}
|
||||
|
||||
func (f *flusher) collectPushedUpdates() []imap.Update {
|
||||
updates := f.pushedUpdates
|
||||
f.pushedUpdates = nil
|
||||
return updates
|
||||
}
|
||||
@ -20,6 +20,7 @@ package user
|
||||
import (
|
||||
"fmt"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
@ -91,3 +92,7 @@ func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
|
||||
|
||||
return sorted
|
||||
}
|
||||
|
||||
func newProtonAPIScheduler() proton.Scheduler {
|
||||
return proton.NewParallelScheduler(runtime.NumCPU() / 2)
|
||||
}
|
||||
|
||||
@ -88,13 +88,12 @@ type User struct {
|
||||
pollAPIEventsCh chan chan struct{}
|
||||
goPollAPIEvents func(wait bool)
|
||||
|
||||
syncWorkers int
|
||||
showAllMail uint32
|
||||
|
||||
maxSyncMemory uint64
|
||||
}
|
||||
|
||||
// New returns a new user.
|
||||
//
|
||||
// nolint:funlen
|
||||
func New(
|
||||
ctx context.Context,
|
||||
encVault *vault.User,
|
||||
@ -102,9 +101,9 @@ func New(
|
||||
reporter reporter.Reporter,
|
||||
apiUser proton.User,
|
||||
crashHandler async.PanicHandler,
|
||||
syncWorkers int,
|
||||
showAllMail bool,
|
||||
) (*User, error) { //nolint:funlen
|
||||
maxSyncMemory uint64,
|
||||
) (*User, error) {
|
||||
logrus.WithField("userID", apiUser.ID).Info("Creating new user")
|
||||
|
||||
// Get the user's API addresses.
|
||||
@ -146,8 +145,9 @@ func New(
|
||||
tasks: async.NewGroup(context.Background(), crashHandler),
|
||||
pollAPIEventsCh: make(chan chan struct{}),
|
||||
|
||||
syncWorkers: syncWorkers,
|
||||
showAllMail: b32(showAllMail),
|
||||
|
||||
maxSyncMemory: maxSyncMemory,
|
||||
}
|
||||
|
||||
// Initialize the user's update channels for its current address mode.
|
||||
@ -193,7 +193,15 @@ func New(
|
||||
// Sync the user.
|
||||
user.syncAbort.Do(ctx, func(ctx context.Context) {
|
||||
if user.vault.SyncStatus().IsComplete() {
|
||||
user.log.Info("Sync already complete, skipping")
|
||||
user.log.Info("Sync already complete, only system label will be updated")
|
||||
|
||||
if err := user.syncSystemLabels(ctx); err != nil {
|
||||
user.log.WithError(err).Error("Failed to update system labels")
|
||||
return
|
||||
}
|
||||
|
||||
user.log.Info("System label update complete, starting API event stream")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
@ -257,11 +265,13 @@ func (user *User) Match(query string) bool {
|
||||
}, user.apiUserLock, user.apiAddrsLock)
|
||||
}
|
||||
|
||||
// Emails returns all the user's email addresses.
|
||||
// Emails returns all the user's active email addresses.
|
||||
// It returns them in sorted order; the user's primary address is first.
|
||||
func (user *User) Emails() []string {
|
||||
return safe.RLockRet(func() []string {
|
||||
addresses := maps.Values(user.apiAddrs)
|
||||
addresses := xslices.Filter(maps.Values(user.apiAddrs), func(addr proton.Address) bool {
|
||||
return addr.Status == proton.AddressStatusEnabled
|
||||
})
|
||||
|
||||
slices.SortFunc(addresses, func(a, b proton.Address) bool {
|
||||
return a.Order < b.Order
|
||||
@ -432,8 +442,6 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
|
||||
}
|
||||
|
||||
// SendMail sends an email from the given address to the given recipients.
|
||||
//
|
||||
// nolint:funlen
|
||||
func (user *User) SendMail(authID string, from string, to []string, r io.Reader) error {
|
||||
if user.vault.SyncStatus().IsComplete() {
|
||||
defer user.goPollAPIEvents(true)
|
||||
@ -636,13 +644,11 @@ func (user *User) startEvents(ctx context.Context) {
|
||||
}
|
||||
|
||||
// doEventPoll is called whenever API events should be polled.
|
||||
//
|
||||
//nolint:funlen
|
||||
func (user *User) doEventPoll(ctx context.Context) error {
|
||||
user.eventLock.Lock()
|
||||
defer user.eventLock.Unlock()
|
||||
|
||||
event, err := user.client.GetEvent(ctx, user.vault.EventID())
|
||||
event, _, err := user.client.GetEvent(ctx, user.vault.EventID())
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get event (caused by %T): %w", internal.ErrCause(err), err)
|
||||
}
|
||||
|
||||
@ -119,14 +119,14 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
|
||||
saltedKeyPass, err := salts.SaltForKey([]byte(password), apiUser.Keys.Primary().ID)
|
||||
require.NoError(tb, err)
|
||||
|
||||
vault, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"))
|
||||
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"))
|
||||
require.NoError(tb, err)
|
||||
require.False(tb, corrupt)
|
||||
|
||||
vaultUser, err := vault.AddUser(apiUser.ID, username, username+"@pm.me", apiAuth.UID, apiAuth.RefreshToken, saltedKeyPass)
|
||||
vaultUser, err := v.AddUser(apiUser.ID, username, username+"@pm.me", apiAuth.UID, apiAuth.RefreshToken, saltedKeyPass)
|
||||
require.NoError(tb, err)
|
||||
|
||||
user, err := New(ctx, vaultUser, client, nil, apiUser, nil, vault.SyncWorkers(), true)
|
||||
user, err := New(ctx, vaultUser, client, nil, apiUser, nil, true, vault.DefaultMaxSyncMemory)
|
||||
require.NoError(tb, err)
|
||||
defer user.Close()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user