chore: merge release/perth_narrows into devel

Jakub
2023-03-13 11:39:04 +01:00
159 changed files with 8308 additions and 1899 deletions

View File

@ -18,6 +18,7 @@
package user
import (
"bytes"
"context"
"errors"
"fmt"
@ -67,6 +68,10 @@ func (user *User) handleAPIEvent(ctx context.Context, event proton.Event) error
}
}
if event.UsedSpace != nil {
user.handleUsedSpaceChange(*event.UsedSpace)
}
return nil
}
@ -194,6 +199,12 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
user.apiAddrs[event.Address.ID] = event.Address
// If the address is disabled.
if event.Address.Status != proton.AddressStatusEnabled {
return nil
}
// If the address is enabled, we need to hook it up to the update channels.
switch user.vault.AddressMode() {
case vault.CombinedMode:
primAddr, err := getAddrIdx(user.apiAddrs, 0)
@ -220,6 +231,10 @@ func (user *User) handleCreateAddressEvent(ctx context.Context, event proton.Add
// Perform the sync in an RLock.
return safe.RLockRet(func() error {
if event.Address.Status != proton.AddressStatusEnabled {
return nil
}
if user.vault.AddressMode() == vault.SplitMode {
if err := syncLabels(ctx, user.apiLabels, user.updateCh[event.Address.ID]); err != nil {
return fmt.Errorf("failed to sync labels to new address: %w", err)
@ -237,18 +252,58 @@ func (user *User) handleUpdateAddressEvent(_ context.Context, event proton.Addre
"email": logging.Sensitive(event.Address.Email),
}).Info("Handling address updated event")
if _, ok := user.apiAddrs[event.Address.ID]; !ok {
oldAddr, ok := user.apiAddrs[event.Address.ID]
if !ok {
user.log.Debugf("Address %q does not exist", event.Address.ID)
return nil
}
user.apiAddrs[event.Address.ID] = event.Address
user.eventCh.Enqueue(events.UserAddressUpdated{
UserID: user.apiUser.ID,
AddressID: event.Address.ID,
Email: event.Address.Email,
})
switch {
// If the address was newly enabled:
case oldAddr.Status != proton.AddressStatusEnabled && event.Address.Status == proton.AddressStatusEnabled:
switch user.vault.AddressMode() {
case vault.CombinedMode:
primAddr, err := getAddrIdx(user.apiAddrs, 0)
if err != nil {
return fmt.Errorf("failed to get primary address: %w", err)
}
user.updateCh[event.Address.ID] = user.updateCh[primAddr.ID]
case vault.SplitMode:
user.updateCh[event.Address.ID] = queue.NewQueuedChannel[imap.Update](0, 0)
}
user.eventCh.Enqueue(events.UserAddressEnabled{
UserID: user.apiUser.ID,
AddressID: event.Address.ID,
Email: event.Address.Email,
})
// If the address was newly disabled:
case oldAddr.Status == proton.AddressStatusEnabled && event.Address.Status != proton.AddressStatusEnabled:
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.ID].CloseAndDiscardQueued()
}
delete(user.updateCh, event.ID)
user.eventCh.Enqueue(events.UserAddressDisabled{
UserID: user.apiUser.ID,
AddressID: event.Address.ID,
Email: event.Address.Email,
})
// Otherwise it's just an update:
default:
user.eventCh.Enqueue(events.UserAddressUpdated{
UserID: user.apiUser.ID,
AddressID: event.Address.ID,
Email: event.Address.Email,
})
}
return nil
}, user.apiAddrsLock, user.updateChLock)
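The three branches above reduce to a comparison of the old and new address status. As a minimal standalone sketch (hypothetical helper, not part of this commit; it assumes the status field's type is proton.AddressStatus), the classification looks like:

func classifyAddressTransition(oldStatus, newStatus proton.AddressStatus) string {
	switch {
	case oldStatus != proton.AddressStatusEnabled && newStatus == proton.AddressStatusEnabled:
		return "enabled" // hook the address up to an update channel and emit events.UserAddressEnabled
	case oldStatus == proton.AddressStatusEnabled && newStatus != proton.AddressStatusEnabled:
		return "disabled" // drop the update channel and emit events.UserAddressDisabled
	default:
		return "updated" // plain metadata change; emit events.UserAddressUpdated
	}
}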
@ -264,12 +319,20 @@ func (user *User) handleDeleteAddressEvent(_ context.Context, event proton.Addre
return nil
}
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.ID].CloseAndDiscardQueued()
delete(user.updateCh, event.ID)
delete(user.apiAddrs, event.ID)
// If the address was disabled to begin with, we don't need to do anything.
if addr.Status != proton.AddressStatusEnabled {
return nil
}
delete(user.apiAddrs, event.ID)
// Otherwise, in split mode, drop the update queue.
if user.vault.AddressMode() == vault.SplitMode {
user.updateCh[event.ID].CloseAndDiscardQueued()
}
// And in either mode, remove the address from the update channel map.
delete(user.updateCh, event.ID)
user.eventCh.Enqueue(events.UserAddressDeleted{
UserID: user.apiUser.ID,
@ -356,25 +419,51 @@ func (user *User) handleUpdateLabelEvent(ctx context.Context, event proton.Label
"name": logging.Sensitive(event.Label.Name),
}).Info("Handling label updated event")
// Only update the label if it exists; we don't want to create it as a client may have just deleted it.
if _, ok := user.apiLabels[event.Label.ID]; ok {
user.apiLabels[event.Label.ID] = event.Label
}
stack := []proton.Label{event.Label}
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
update := imap.NewMailboxUpdated(
imap.MailboxID(event.ID),
getMailboxName(event.Label),
)
updateCh.Enqueue(update)
updates = append(updates, update)
}
for len(stack) > 0 {
label := stack[0]
stack = stack[1:]
user.eventCh.Enqueue(events.UserLabelUpdated{
UserID: user.apiUser.ID,
LabelID: event.Label.ID,
Name: event.Label.Name,
})
// Only update the label if it exists; we don't want to create it as a client may have just deleted it.
if _, ok := user.apiLabels[label.ID]; ok {
user.apiLabels[label.ID] = event.Label
}
// API doesn't notify us that the path has changed. We need to fetch it again.
apiLabel, err := user.client.GetLabel(ctx, label.ID, label.Type)
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
user.log.WithError(apiErr).Warn("Failed to get label: label does not exist")
continue
} else if err != nil {
return nil, fmt.Errorf("failed to get label %q: %w", label.ID, err)
}
// Update the label in the map.
user.apiLabels[apiLabel.ID] = apiLabel
// Notify the IMAP clients.
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
update := imap.NewMailboxUpdated(
imap.MailboxID(apiLabel.ID),
getMailboxName(apiLabel),
)
updateCh.Enqueue(update)
updates = append(updates, update)
}
user.eventCh.Enqueue(events.UserLabelUpdated{
UserID: user.apiUser.ID,
LabelID: apiLabel.ID,
Name: apiLabel.Name,
})
children := xslices.Filter(maps.Values(user.apiLabels), func(other proton.Label) bool {
return other.ParentID == label.ID
})
stack = append(stack, children...)
}
return updates, nil
}, user.apiLabelsLock, user.updateChLock)
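Because a parent rename changes the full path of every descendant, the handler above walks the label tree with an explicit stack. A compact sketch of the same traversal (hypothetical helper, not part of this commit), collecting a label and all of its descendants from the in-memory label map:

func collectLabelSubtree(apiLabels map[string]proton.Label, rootID string) []string {
	var ids []string

	stack := []string{rootID}

	for len(stack) > 0 {
		id := stack[0]
		stack = stack[1:]

		ids = append(ids, id)

		// Children are found by scanning for labels whose ParentID matches.
		for _, other := range apiLabels {
			if other.ParentID == id {
				stack = append(stack, other.ID)
			}
		}
	}

	return ids
}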
@ -404,7 +493,7 @@ func (user *User) handleDeleteLabelEvent(ctx context.Context, event proton.Label
}
// handleMessageEvents handles the given message events.
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []proton.MessageEvent) error { //nolint:funlen
func (user *User) handleMessageEvents(ctx context.Context, messageEvents []proton.MessageEvent) error {
for _, event := range messageEvents {
ctx = logging.WithLogrusField(ctx, "messageID", event.ID)
@ -494,7 +583,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
"subject": logging.Sensitive(message.Subject),
}).Info("Handling message created event")
full, err := user.client.GetFullMessage(ctx, message.ID)
full, err := user.client.GetFullMessage(ctx, message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
if err != nil {
// If the message is not found, it means that it has been deleted before we could fetch it.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
@ -509,7 +598,7 @@ func (user *User) handleCreateMessageEvent(ctx context.Context, message proton.M
var update imap.Update
if err := withAddrKR(user.apiUser, user.apiAddrs[message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
res := buildRFC822(user.apiLabels, full, addrKR)
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
if res.err != nil {
user.log.WithError(err).Error("Failed to build RFC822 message")
@ -553,7 +642,7 @@ func (user *User) handleUpdateMessageEvent(ctx context.Context, message proton.M
Seen: message.Seen(),
Flagged: message.Starred(),
Draft: message.IsDraft(),
Answered: message.IsReplied == true || message.IsRepliedAll == true, //nolint: gosimple
Answered: message.IsRepliedAll == true || message.IsReplied == true, //nolint: gosimple
},
)
@ -586,7 +675,7 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
"subject": logging.Sensitive(event.Message.Subject),
}).Info("Handling draft updated event")
full, err := user.client.GetFullMessage(ctx, event.Message.ID)
full, err := user.client.GetFullMessage(ctx, event.Message.ID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
if err != nil {
// If the message is not found, it means that it has been deleted before we could fetch it.
if apiErr := new(proton.APIError); errors.As(err, &apiErr) && apiErr.Status == http.StatusUnprocessableEntity {
@ -600,7 +689,7 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
var update imap.Update
if err := withAddrKR(user.apiUser, user.apiAddrs[event.Message.AddressID], user.vault.KeyPass(), func(_, addrKR *crypto.KeyRing) error {
res := buildRFC822(user.apiLabels, full, addrKR)
res := buildRFC822(user.apiLabels, full, addrKR, new(bytes.Buffer))
if res.err != nil {
logrus.WithError(err).Error("Failed to build RFC822 message")
@ -637,6 +726,20 @@ func (user *User) handleUpdateDraftEvent(ctx context.Context, event proton.Messa
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
}
func (user *User) handleUsedSpaceChange(usedSpace int) {
safe.Lock(func() {
if user.apiUser.UsedSpace == usedSpace {
return
}
user.apiUser.UsedSpace = usedSpace
user.eventCh.Enqueue(events.UsedSpaceChanged{
UserID: user.apiUser.ID,
UsedSpace: usedSpace,
})
}, user.apiUserLock)
}
func getMailboxName(label proton.Label) []string {
var name []string

View File

@ -264,8 +264,6 @@ func (conn *imapConnector) DeleteMailbox(ctx context.Context, labelID imap.Mailb
}
// CreateMessage creates a new message on the remote.
//
// nolint:funlen
func (conn *imapConnector) CreateMessage(
ctx context.Context,
mailboxID imap.MailboxID,
@ -292,7 +290,7 @@ func (conn *imapConnector) CreateMessage(
conn.log.WithField("messageID", messageID).Warn("Message already sent")
// Query the server-side message.
full, err := conn.client.GetFullMessage(ctx, messageID)
full, err := conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
if err != nil {
return imap.Message{}, nil, fmt.Errorf("failed to fetch message: %w", err)
}
@ -356,7 +354,7 @@ func (conn *imapConnector) CreateMessage(
}
func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.MessageID) ([]byte, error) {
msg, err := conn.client.GetFullMessage(ctx, string(id))
msg, err := conn.client.GetFullMessage(ctx, string(id), newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
if err != nil {
return nil, err
}
@ -382,7 +380,7 @@ func (conn *imapConnector) GetMessageLiteral(ctx context.Context, id imap.Messag
func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs []imap.MessageID, mailboxID imap.MailboxID) error {
defer conn.goPollAPIEvents(false)
if mailboxID == proton.AllMailLabel {
if isAllMailOrScheduled(mailboxID) {
return connector.ErrOperationNotAllowed
}
@ -393,7 +391,7 @@ func (conn *imapConnector) AddMessagesToMailbox(ctx context.Context, messageIDs
func (conn *imapConnector) RemoveMessagesFromMailbox(ctx context.Context, messageIDs []imap.MessageID, mailboxID imap.MailboxID) error {
defer conn.goPollAPIEvents(false)
if mailboxID == proton.AllMailLabel {
if isAllMailOrScheduled(mailboxID) {
return connector.ErrOperationNotAllowed
}
@ -442,8 +440,8 @@ func (conn *imapConnector) MoveMessages(ctx context.Context, messageIDs []imap.M
if (labelFromID == proton.InboxLabel && labelToID == proton.SentLabel) ||
(labelFromID == proton.SentLabel && labelToID == proton.InboxLabel) ||
labelFromID == proton.AllMailLabel ||
labelToID == proton.AllMailLabel {
isAllMailOrScheduled(labelFromID) ||
isAllMailOrScheduled(labelToID) {
return false, connector.ErrOperationNotAllowed
}
@ -507,19 +505,20 @@ func (conn *imapConnector) GetUpdates() <-chan imap.Update {
}, conn.updateChLock)
}
// GetUIDValidity returns the default UID validity for this user.
func (conn *imapConnector) GetUIDValidity() imap.UID {
return conn.vault.GetUIDValidity(conn.addrID)
}
// GetMailboxVisibility returns the visibility of a mailbox over IMAP.
func (conn *imapConnector) GetMailboxVisibility(_ context.Context, mailboxID imap.MailboxID) imap.MailboxVisibility {
switch mailboxID {
case proton.AllMailLabel:
if atomic.LoadUint32(&conn.showAllMail) != 0 {
return imap.Visible
}
return imap.Hidden
// SetUIDValidity sets the default UID validity for this user.
func (conn *imapConnector) SetUIDValidity(validity imap.UID) error {
return conn.vault.SetUIDValidity(conn.addrID, validity)
}
// IsMailboxVisible returns whether this mailbox should be visible over IMAP.
func (conn *imapConnector) IsMailboxVisible(_ context.Context, mailboxID imap.MailboxID) bool {
return atomic.LoadUint32(&conn.showAllMail) != 0 || mailboxID != proton.AllMailLabel
case proton.AllScheduledLabel:
return imap.HiddenIfEmpty
default:
return imap.Visible
}
}
// Close is called when the connector will no longer be used and all resources should be closed/released.
@ -550,7 +549,7 @@ func (conn *imapConnector) importMessage(
messageID = msg.ID
} else {
res, err := stream.Collect(ctx, conn.client.ImportMessages(ctx, addrKR, 1, 1, []proton.ImportReq{{
str, err := conn.client.ImportMessages(ctx, addrKR, 1, 1, []proton.ImportReq{{
Metadata: proton.ImportMetadata{
AddressID: conn.addrID,
LabelIDs: labelIDs,
@ -558,7 +557,12 @@ func (conn *imapConnector) importMessage(
Flags: flags,
},
Message: literal,
}}...))
}}...)
if err != nil {
return fmt.Errorf("failed to prepare message for import: %w", err)
}
res, err := stream.Collect(ctx, str)
if err != nil {
return fmt.Errorf("failed to import message: %w", err)
}
@ -568,7 +572,7 @@ func (conn *imapConnector) importMessage(
var err error
if full, err = conn.client.GetFullMessage(ctx, messageID); err != nil {
if full, err = conn.client.GetFullMessage(ctx, messageID, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator()); err != nil {
return fmt.Errorf("failed to fetch message: %w", err)
}
@ -615,7 +619,7 @@ func toIMAPMessage(message proton.MessageMetadata) imap.Message {
}
}
func (conn *imapConnector) createDraft(ctx context.Context, literal []byte, addrKR *crypto.KeyRing, sender proton.Address) (proton.Message, error) { //nolint:funlen
func (conn *imapConnector) createDraft(ctx context.Context, literal []byte, addrKR *crypto.KeyRing, sender proton.Address) (proton.Message, error) {
// Create a new message parser from the reader.
parser, err := parser.New(bytes.NewReader(literal))
if err != nil {
@ -687,3 +691,7 @@ func toIMAPMailbox(label proton.Label, flags, permFlags, attrs imap.FlagSet) ima
Attributes: attrs,
}
}
func isAllMailOrScheduled(mailboxID imap.MailboxID) bool {
return (mailboxID == proton.AllMailLabel) || (mailboxID == proton.AllScheduledLabel)
}
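A table-driven test for the new guard could look like the following (sketch only, not part of this commit; it assumes the standard testing package plus the imap and proton imports already used in this file):

func TestIsAllMailOrScheduled(t *testing.T) {
	tests := []struct {
		mailboxID imap.MailboxID
		want      bool
	}{
		{imap.MailboxID(proton.AllMailLabel), true},
		{imap.MailboxID(proton.AllScheduledLabel), true},
		{imap.MailboxID(proton.InboxLabel), false},
	}

	for _, tc := range tests {
		if got := isAllMailOrScheduled(tc.mailboxID); got != tc.want {
			t.Errorf("isAllMailOrScheduled(%v) = %v, want %v", tc.mailboxID, got, tc.want)
		}
	}
}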

View File

@ -218,8 +218,6 @@ func (h *sendRecorder) getWaitCh(hash string) (<-chan struct{}, bool) {
// - the Content-Type header of each (leaf) part,
// - the Content-Disposition header of each (leaf) part,
// - the (decoded) body of each part.
//
// nolint:funlen
func getMessageHash(b []byte) (string, error) {
section := rfc822.Parse(b)

View File

@ -47,8 +47,6 @@ import (
)
// sendMail sends an email from the given address to the given recipients.
//
// nolint:funlen
func (user *User) sendMail(authID string, from string, to []string, r io.Reader) error {
return safe.RLockRet(func() error {
ctx, cancel := context.WithCancel(context.Background())
@ -165,7 +163,7 @@ func (user *User) sendMail(authID string, from string, to []string, r io.Reader)
}
// sendWithKey sends the message with the given address key.
func sendWithKey( //nolint:funlen
func sendWithKey(
ctx context.Context,
client *proton.Client,
sentry reporter.Reporter,
@ -247,7 +245,7 @@ func sendWithKey( //nolint:funlen
return res, nil
}
func getParentID( //nolint:funlen
func getParentID(
ctx context.Context,
client *proton.Client,
authAddrID string,
@ -375,7 +373,6 @@ func createDraft(
})
}
// nolint:funlen
func createAttachments(
ctx context.Context,
client *proton.Client,
@ -468,12 +465,12 @@ func getRecipients(
prefs, err := parallel.MapContext(ctx, runtime.NumCPU(), addresses, func(ctx context.Context, recipient string) (proton.SendPreferences, error) {
pubKeys, recType, err := client.GetPublicKeys(ctx, recipient)
if err != nil {
return proton.SendPreferences{}, fmt.Errorf("failed to get public keys: %w (%v)", err, recipient)
return proton.SendPreferences{}, fmt.Errorf("failed to get public key for %v: %w", recipient, err)
}
contactSettings, err := getContactSettings(ctx, client, userKR, recipient)
if err != nil {
return proton.SendPreferences{}, fmt.Errorf("failed to get contact settings: %w", err)
return proton.SendPreferences{}, fmt.Errorf("failed to get contact settings for %v: %w", recipient, err)
}
return buildSendPrefs(contactSettings, settings, pubKeys, draft.MIMEType, recType == proton.RecipientTypeInternal)

View File

@ -18,6 +18,7 @@
package user
import (
"bytes"
"context"
"fmt"
"runtime"
@ -25,6 +26,7 @@ import (
"time"
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/logging"
"github.com/ProtonMail/gluon/queue"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
@ -34,18 +36,38 @@ import (
"github.com/ProtonMail/proton-bridge/v3/internal/vault"
"github.com/bradenaw/juniper/parallel"
"github.com/bradenaw/juniper/xslices"
"github.com/google/uuid"
"github.com/pbnjay/memory"
"github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"golang.org/x/exp/slices"
)
const (
maxUpdateSize = 1 << 27 // 128 MiB
maxBatchSize = 1 << 8 // 256
)
// syncSystemLabels ensures that system labels are all known to gluon.
func (user *User) syncSystemLabels(ctx context.Context) error {
return safe.RLockRet(func() error {
var updates []imap.Update
// doSync begins syncing the users data.
for _, label := range xslices.Filter(maps.Values(user.apiLabels), func(label proton.Label) bool { return label.Type == proton.LabelTypeSystem }) {
if !wantLabel(label) {
continue
}
for _, updateCh := range xslices.Unique(maps.Values(user.updateCh)) {
update := newSystemMailboxCreatedUpdate(imap.MailboxID(label.ID), label.Name)
updateCh.Enqueue(update)
updates = append(updates, update)
}
}
if err := waitOnIMAPUpdates(ctx, updates); err != nil {
return fmt.Errorf("could not sync system labels: %w", err)
}
return nil
}, user.apiUserLock, user.apiAddrsLock, user.apiLabelsLock, user.updateChLock)
}
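syncSystemLabels enqueues one MailboxCreated update per wanted system label on every unique update channel and then blocks until gluon has applied them. waitOnIMAPUpdates is defined elsewhere in the package; judging by the up.WaitContext(ctx) pattern used in the sync flush stage further down, a plausible shape (sketch only, may differ from the real helper) is:

func waitOnIMAPUpdatesSketch(ctx context.Context, updates []imap.Update) error {
	for _, up := range updates {
		if err, ok := up.WaitContext(ctx); ok && err != nil {
			return fmt.Errorf("failed to apply update %v: %w", up.String(), err)
		}
	}

	return nil
}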
// doSync begins syncing the user's data.
// It first ensures the latest event ID is known; if not, it fetches it.
// It sends a SyncStarted event and then either SyncFinished or SyncFailed
// depending on whether the sync was successful.
@ -89,7 +111,6 @@ func (user *User) doSync(ctx context.Context) error {
return nil
}
// nolint:funlen
func (user *User) sync(ctx context.Context) error {
return safe.RLockRet(func() error {
return withAddrKRs(user.apiUser, user.apiAddrs, user.vault.KeyPass(), func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
@ -143,7 +164,7 @@ func (user *User) sync(ctx context.Context) error {
addrKRs,
user.updateCh,
user.eventCh,
user.syncWorkers,
user.maxSyncMemory,
); err != nil {
return fmt.Errorf("failed to sync messages: %w", err)
}
@ -166,7 +187,7 @@ func (user *User) sync(ctx context.Context) error {
func syncLabels(ctx context.Context, apiLabels map[string]proton.Label, updateCh ...*queue.QueuedChannel[imap.Update]) error {
var updates []imap.Update
// Create placeholder Folders/Labels mailboxes with a random ID and with the \Noselect attribute.
// Create placeholder Folders/Labels mailboxes with the \Noselect attribute.
for _, prefix := range []string{folderPrefix, labelPrefix} {
for _, updateCh := range updateCh {
update := newPlaceHolderMailboxCreatedUpdate(prefix)
@ -212,7 +233,15 @@ func syncLabels(ctx context.Context, apiLabels map[string]proton.Label, updateCh
return nil
}
// nolint:funlen
const Kilobyte = uint64(1024)
const Megabyte = 1024 * Kilobyte
const Gigabyte = 1024 * Megabyte
func toMB(v uint64) float64 {
return float64(v) / float64(Megabyte)
}
// nolint:gocyclo
func syncMessages(
ctx context.Context,
userID string,
@ -224,7 +253,7 @@ func syncMessages(
addrKRs map[string]*crypto.KeyRing,
updateCh map[string]*queue.QueuedChannel[imap.Update],
eventCh *queue.QueuedChannel[events.Event],
syncWorkers int,
maxSyncMemory uint64,
) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
@ -235,78 +264,296 @@ func syncMessages(
logrus.WithFields(logrus.Fields{
"messages": len(messageIDs),
"workers": syncWorkers,
"numCPU": runtime.NumCPU(),
}).Info("Starting message sync")
// Create the flushers, one per update channel.
flushers := make(map[string]*flusher, len(updateCh))
for addrID, updateCh := range updateCh {
flushers[addrID] = newFlusher(updateCh, maxUpdateSize)
}
// Create a reporter to report sync progress updates.
syncReporter := newSyncReporter(userID, eventCh, len(messageIDs), time.Second)
defer syncReporter.done()
type flushUpdate struct {
messageID string
pushedUpdates []imap.Update
batchLen int
// Expected mem usage for this whole process should be the sum of MaxMessageBuildingMem and MaxDownloadRequestMem
// times a constant factor (roughly 4x, see below), due to pipelining and the additional memory used by network
// requests and compression+IO.
// There's no point in using more than 128MB of download data per stage; after that we reach a point of diminishing
// returns, as we can't keep the pipeline fed fast enough.
const MaxDownloadRequestMem = 128 * Megabyte
// Any lower than this and we may fail to download messages.
const MinDownloadRequestMem = 40 * Megabyte
// This value can be increased to your heart's content. The more system memory the user has, the more messages
// we can build in parallel.
const MaxMessageBuildingMem = 128 * Megabyte
const MinMessageBuildingMem = 64 * Megabyte
// Maximum recommended value for parallel downloads, per the API team.
const maxParallelDownloads = 20
totalMemory := memory.TotalMemory()
if maxSyncMemory >= totalMemory/2 {
logrus.Warnf("Requested max sync memory of %v MB is greater than half of system memory (%v MB), forcing to half of system memory",
maxSyncMemory, toMB(totalMemory/2))
maxSyncMemory = totalMemory / 2
}
if maxSyncMemory < 800*Megabyte {
logrus.Warnf("Requested max sync memory of %v MB, but minimum recommended is 800 MB, forcing max syncMemory to 800MB", toMB(maxSyncMemory))
maxSyncMemory = 800 * Megabyte
}
logrus.Debugf("Total System Memory: %v", toMB(totalMemory))
syncMaxDownloadRequestMem := MaxDownloadRequestMem
syncMaxMessageBuildingMem := MaxMessageBuildingMem
// If less than 2 GB is available, try to limit max memory to roughly 512 MB.
switch {
case maxSyncMemory < 2*Gigabyte:
if maxSyncMemory < 800*Megabyte {
logrus.Warnf("System has less than 800MB of memory, you may experience issues sycing large mailboxes")
}
syncMaxDownloadRequestMem = MinDownloadRequestMem
syncMaxMessageBuildingMem = MinMessageBuildingMem
case maxSyncMemory == 2*Gigabyte:
// Increasing the max download capacity has very little effect on sync speed. We could increase the download
// memory, but the user would see fewer sync notifications. A smaller value here leads to more frequent
// updates. Additionally, most of the sync time is spent in message building.
syncMaxDownloadRequestMem = MaxDownloadRequestMem
// Currently limited so that a user with multiple active accounts doesn't cause excessive memory usage.
syncMaxMessageBuildingMem = MaxMessageBuildingMem
default:
// Divide by 8, as the download stage and build stage will each use approx. 4x the specified memory.
remainingMemory := (maxSyncMemory - 2*Gigabyte) / 8
syncMaxDownloadRequestMem = MaxDownloadRequestMem + remainingMemory
syncMaxMessageBuildingMem = MaxMessageBuildingMem + remainingMemory
}
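// Worked example (illustrative, not in the original code): on a 16 GiB machine with a
// requested maxSyncMemory of 4 GiB, neither clamp above triggers, so the default branch
// gives remainingMemory = (4 GiB - 2 GiB) / 8 = 256 MiB, hence Download = 128 + 256 = 384 MiB
// and Building = 128 + 256 = 384 MiB; the predicted peak logged below is 4*(384 + 384) MiB,
// i.e. roughly 3 GiB.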
logrus.Debugf("Max memory usage for sync Download=%vMB Building=%vMB Predicted Max Total=%vMB",
toMB(syncMaxDownloadRequestMem),
toMB(syncMaxMessageBuildingMem),
toMB((syncMaxMessageBuildingMem*4)+(syncMaxDownloadRequestMem*4)),
)
type flushUpdate struct {
messageID string
err error
batchLen int
}
type downloadRequest struct {
ids []string
expectedSize uint64
err error
}
type downloadedMessageBatch struct {
batch []proton.FullMessage
}
type builtMessageBatch struct {
batch []*buildRes
}
downloadCh := make(chan downloadRequest)
buildCh := make(chan downloadedMessageBatch)
// The higher this value, the longer we can continue our download iteration before being blocked on channel writes
// to the update flushing goroutine.
flushCh := make(chan []*buildRes, 2)
flushCh := make(chan builtMessageBatch)
// Allow up to 4 batched wait requests.
flushUpdateCh := make(chan flushUpdate, 4)
flushUpdateCh := make(chan flushUpdate)
errorCh := make(chan error, syncWorkers)
errorCh := make(chan error, maxParallelDownloads*4)
// Goroutine in charge of downloading message metadata
logging.GoAnnotated(ctx, func(ctx context.Context) {
defer close(downloadCh)
const MetadataDataPageSize = 150
var downloadReq downloadRequest
downloadReq.ids = make([]string, 0, MetadataDataPageSize)
metadataChunks := xslices.Chunk(messageIDs, MetadataDataPageSize)
for i, metadataChunk := range metadataChunks {
logrus.Debugf("Metadata Request (%v of %v), previous: %v", i, len(metadataChunks), len(downloadReq.ids))
metadata, err := client.GetMessageMetadataPage(ctx, 0, len(metadataChunk), proton.MessageFilter{ID: metadataChunk})
if err != nil {
downloadReq.err = err
select {
case downloadCh <- downloadReq:
case <-ctx.Done():
return
}
return
}
if ctx.Err() != nil {
return
}
// Build a lookup table so that messages are processed in the same order.
metadataMap := make(map[string]int, len(metadata))
for i, v := range metadata {
metadataMap[v.ID] = i
}
for i, id := range metadataChunk {
m := &metadata[metadataMap[id]]
nextSize := downloadReq.expectedSize + uint64(m.Size)
if nextSize >= syncMaxDownloadRequestMem || len(downloadReq.ids) >= 256 {
logrus.Debugf("Download Request Sent at %v of %v", i, len(metadata))
select {
case downloadCh <- downloadReq:
case <-ctx.Done():
return
}
downloadReq.expectedSize = 0
downloadReq.ids = make([]string, 0, MetadataDataPageSize)
nextSize = uint64(m.Size)
}
downloadReq.ids = append(downloadReq.ids, id)
downloadReq.expectedSize = nextSize
}
}
if len(downloadReq.ids) != 0 {
logrus.Debugf("Sending remaining download request")
select {
case downloadCh <- downloadReq:
case <-ctx.Done():
return
}
}
}, logging.Labels{"sync-stage": "meta-data"})
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
go func() {
defer close(flushCh)
logging.GoAnnotated(ctx, func(ctx context.Context) {
defer close(buildCh)
defer close(errorCh)
defer func() {
logrus.Debugf("sync downloader exit")
}()
attachmentDownloader := newAttachmentDownloader(ctx, client, maxParallelDownloads)
defer attachmentDownloader.close()
for request := range downloadCh {
logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize))
if request.err != nil {
errorCh <- request.err
return
}
for _, batch := range xslices.Chunk(messageIDs, maxBatchSize) {
if ctx.Err() != nil {
errorCh <- ctx.Err()
return
}
result, err := parallel.MapContext(ctx, syncWorkers, batch, func(ctx context.Context, id string) (*buildRes, error) {
msg, err := client.GetFullMessage(ctx, id)
result, err := parallel.MapContext(ctx, maxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
var result proton.FullMessage
msg, err := client.GetMessage(ctx, id)
if err != nil {
return nil, err
return proton.FullMessage{}, err
}
if ctx.Err() != nil {
return nil, ctx.Err()
attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments)
if err != nil {
return proton.FullMessage{}, err
}
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID]), nil
result.Message = msg
result.AttData = attachments
return result, nil
})
if err != nil {
errorCh <- err
return
}
select {
case buildCh <- downloadedMessageBatch{
batch: result,
}:
case <-ctx.Done():
return
}
}
}, logging.Labels{"sync-stage": "download"})
// Goroutine which builds messages after they have been downloaded
logging.GoAnnotated(ctx, func(ctx context.Context) {
defer close(flushCh)
defer func() {
logrus.Debugf("sync builder exit")
}()
maxMessagesInParallel := runtime.NumCPU()
for buildBatch := range buildCh {
if ctx.Err() != nil {
errorCh <- ctx.Err()
return
}
flushCh <- result
chunks := chunkSyncBuilderBatch(buildBatch.batch, syncMaxMessageBuildingMem)
for index, chunk := range chunks {
logrus.Debugf("Build request: %v of %v count=%v", index, len(chunks), len(chunk))
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil
})
if err != nil {
return
}
select {
case flushCh <- builtMessageBatch{result}:
case <-ctx.Done():
return
}
}
}
}()
}, logging.Labels{"sync-stage": "builder"})
// Goroutine which converts the messages into updates and builds a waitable structure for progress tracking.
go func() {
logging.GoAnnotated(ctx, func(ctx context.Context) {
defer close(flushUpdateCh)
for batch := range flushCh {
for _, res := range batch {
defer func() {
logrus.Debugf("sync flush exit")
}()
type updateTargetInfo struct {
queueIndex int
ch *queue.QueuedChannel[imap.Update]
}
pendingUpdates := make([][]*imap.MessageCreated, len(updateCh))
addressToIndex := make(map[string]updateTargetInfo)
{
i := 0
for addrID, updateCh := range updateCh {
addressToIndex[addrID] = updateTargetInfo{
ch: updateCh,
queueIndex: i,
}
i++
}
}
for downloadBatch := range flushCh {
logrus.Debugf("Flush batch: %v", len(downloadBatch.batch))
for _, res := range downloadBatch.batch {
if res.err != nil {
if err := vault.AddFailedMessageID(res.messageID); err != nil {
logrus.WithError(err).Error("Failed to add failed message ID")
@ -327,31 +574,38 @@ func syncMessages(
}
}
flushers[res.addressID].push(res.update)
targetInfo := addressToIndex[res.addressID]
pendingUpdates[targetInfo.queueIndex] = append(pendingUpdates[targetInfo.queueIndex], res.update)
}
var pushedUpdates []imap.Update
for _, flusher := range flushers {
flusher.flush()
pushedUpdates = append(pushedUpdates, flusher.collectPushedUpdates()...)
for _, info := range addressToIndex {
up := imap.NewMessagesCreated(true, pendingUpdates[info.queueIndex]...)
info.ch.Enqueue(up)
err, ok := up.WaitContext(ctx)
if ok && err != nil {
flushUpdateCh <- flushUpdate{
err: fmt.Errorf("failed to apply sync update to gluon %v: %w", up.String(), err),
}
return
}
pendingUpdates[info.queueIndex] = pendingUpdates[info.queueIndex][:0]
}
flushUpdateCh <- flushUpdate{
messageID: batch[0].messageID,
pushedUpdates: pushedUpdates,
batchLen: len(batch),
select {
case flushUpdateCh <- flushUpdate{
messageID: downloadBatch.batch[0].messageID,
err: nil,
batchLen: len(downloadBatch.batch),
}:
case <-ctx.Done():
return
}
}
}()
}, logging.Labels{"sync-stage": "flush"})
for flushUpdate := range flushUpdateCh {
for _, up := range flushUpdate.pushedUpdates {
err, ok := up.WaitContext(ctx)
if ok && err != nil {
return fmt.Errorf("failed to apply sync update to gluon %v: %w", up.String(), err)
}
}
if err := vault.SetLastMessageID(flushUpdate.messageID); err != nil {
return fmt.Errorf("failed to set last synced message ID: %w", err)
}
@ -394,6 +648,9 @@ func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *im
case proton.StarredLabel:
attrs = attrs.Add(imap.AttrFlagged)
case proton.AllScheduledLabel:
labelName = "Scheduled" // API actual name is "All Scheduled"
}
return imap.NewMailboxCreated(imap.Mailbox{
@ -407,7 +664,7 @@ func newSystemMailboxCreatedUpdate(labelID imap.MailboxID, labelName string) *im
func newPlaceHolderMailboxCreatedUpdate(labelName string) *imap.MailboxCreated {
return imap.NewMailboxCreated(imap.Mailbox{
ID: imap.MailboxID(uuid.NewString()),
ID: imap.MailboxID(labelName),
Name: []string{labelName},
Flags: defaultFlags,
PermanentFlags: defaultPermanentFlags,
@ -456,6 +713,9 @@ func wantLabel(label proton.Label) bool {
case proton.StarredLabel:
return true
case proton.AllScheduledLabel:
return true
default:
return false
}
@ -471,3 +731,126 @@ func wantLabels(apiLabels map[string]proton.Label, labelIDs []string) []string {
return wantLabel(apiLabel)
})
}
type attachmentResult struct {
attachment []byte
err error
}
type attachmentJob struct {
id string
size int64
result chan attachmentResult
}
type attachmentDownloader struct {
workerCh chan attachmentJob
cancel context.CancelFunc
}
func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan attachmentJob) {
for {
select {
case <-ctx.Done():
return
case job, ok := <-work:
if !ok {
return
}
var b bytes.Buffer
b.Grow(int(job.size))
err := client.GetAttachmentInto(ctx, job.id, &b)
select {
case <-ctx.Done():
close(job.result)
return
case job.result <- attachmentResult{attachment: b.Bytes(), err: err}:
close(job.result)
}
}
}
}
func newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader {
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
ctx, cancel := context.WithCancel(ctx)
for i := 0; i < workerCount; i++ {
logging.GoAnnotated(ctx, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{
"sync": fmt.Sprintf("att-downloader %v", i),
})
}
return &attachmentDownloader{
workerCh: workerCh,
cancel: cancel,
}
}
func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) {
resultChs := make([]chan attachmentResult, len(attachments))
for i, id := range attachments {
resultChs[i] = make(chan attachmentResult, 1)
select {
case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}:
case <-ctx.Done():
return nil, ctx.Err()
}
}
result := make([][]byte, len(attachments))
var err error
for i := 0; i < len(attachments); i++ {
select {
case <-ctx.Done():
return nil, ctx.Err()
case r := <-resultChs[i]:
if r.err != nil {
err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err)
}
result[i] = r.attachment
}
}
return result, err
}
func (a *attachmentDownloader) close() {
a.cancel()
}
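Putting the downloader pieces together, a minimal (hypothetical) call site looks like the following; the real consumer is the download stage of syncMessages above, which fetches attachments for each message returned by GetMessage:

func exampleAttachmentFetch(ctx context.Context, client *proton.Client, msg proton.Message) ([][]byte, error) {
	dl := newAttachmentDownloader(ctx, client, 4) // four workers, well under maxParallelDownloads
	defer dl.close()

	return dl.getAttachments(ctx, msg.Attachments)
}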
func chunkSyncBuilderBatch(batch []proton.FullMessage, maxMemory uint64) [][]proton.FullMessage {
var expectedMemUsage uint64
var chunks [][]proton.FullMessage
var lastIndex int
var index int
for _, v := range batch {
var dataSize uint64
for _, a := range v.Attachments {
dataSize += uint64(a.Size)
}
// 2x increase for attachments due to the extra memory needed for decrypting and writing
// into the in-memory buffer.
dataSize *= 2
dataSize += uint64(len(v.Body))
nextMemSize := expectedMemUsage + dataSize
if nextMemSize >= maxMemory {
chunks = append(chunks, batch[lastIndex:index])
lastIndex = index
expectedMemUsage = dataSize
} else {
expectedMemUsage = nextMemSize
}
index++
}
if lastIndex < len(batch) {
chunks = append(chunks, batch[lastIndex:])
}
return chunks
}
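// Worked check of the chunker (illustrative): with maxMemory = 16 MiB and messages whose
// single 8 MiB attachment is counted twice (16 MiB each, no body), every message meets the
// cap and forces a chunk boundary, and the trailing append of batch[lastIndex:] guarantees
// the final messages are not dropped. This is the GODT-2424 regression exercised by
// TestSyncChunkSyncBuilderBatch in the next file.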

View File

@ -48,16 +48,18 @@ func defaultJobOpts() message.JobOptions {
}
}
func buildRFC822(apiLabels map[string]proton.Label, full proton.FullMessage, addrKR *crypto.KeyRing) *buildRes {
func buildRFC822(apiLabels map[string]proton.Label, full proton.FullMessage, addrKR *crypto.KeyRing, buffer *bytes.Buffer) *buildRes {
var (
update *imap.MessageCreated
err error
)
if literal, buildErr := message.BuildRFC822(addrKR, full.Message, full.AttData, defaultJobOpts()); buildErr != nil {
buffer.Grow(full.Size)
if buildErr := message.BuildRFC822Into(addrKR, full.Message, full.AttData, defaultJobOpts(), buffer); buildErr != nil {
update = newMessageCreatedFailedUpdate(apiLabels, full.MessageMetadata, buildErr)
err = buildErr
} else if created, parseErr := newMessageCreatedUpdate(apiLabels, full.MessageMetadata, literal); parseErr != nil {
} else if created, parseErr := newMessageCreatedUpdate(apiLabels, full.MessageMetadata, buffer.Bytes()); parseErr != nil {
update = newMessageCreatedFailedUpdate(apiLabels, full.MessageMetadata, parseErr)
err = parseErr
} else {

View File

@ -24,6 +24,8 @@ import (
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/rfc822"
"github.com/ProtonMail/go-proton-api"
"github.com/bradenaw/juniper/xslices"
"github.com/stretchr/testify/require"
)
@ -47,3 +49,32 @@ func TestNewFailedMessageLiteral(t *testing.T) {
require.Equal(t, `("text" "plain" () NIL NIL "base64" 114 2)`, parsed.Body)
require.Equal(t, `("text" "plain" () NIL NIL "base64" 114 2 NIL NIL NIL NIL)`, parsed.Structure)
}
func TestSyncChunkSyncBuilderBatch(t *testing.T) {
// GODT-2424 - Some messages were not fully built due to a bug in the chunking when the total memory used by a
// single message was higher than the maximum we allowed.
const totalMessageCount = 100
msg := proton.FullMessage{
Message: proton.Message{
Attachments: []proton.Attachment{
{
Size: int64(8 * Megabyte),
},
},
},
AttData: nil,
}
messages := xslices.Repeat(msg, totalMessageCount)
chunks := chunkSyncBuilderBatch(messages, 16*Megabyte)
var totalMessagesInChunks int
for _, v := range chunks {
totalMessagesInChunks += len(v)
}
require.Equal(t, totalMessagesInChunks, totalMessageCount)
}

View File

@ -1,63 +0,0 @@
// Copyright (c) 2023 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package user
import (
"github.com/ProtonMail/gluon/imap"
"github.com/ProtonMail/gluon/queue"
)
type flusher struct {
updateCh *queue.QueuedChannel[imap.Update]
updates []*imap.MessageCreated
pushedUpdates []imap.Update
maxUpdateSize int
curChunkSize int
}
func newFlusher(updateCh *queue.QueuedChannel[imap.Update], maxUpdateSize int) *flusher {
return &flusher{
updateCh: updateCh,
maxUpdateSize: maxUpdateSize,
}
}
func (f *flusher) push(update *imap.MessageCreated) {
f.updates = append(f.updates, update)
if f.curChunkSize += len(update.Literal); f.curChunkSize >= f.maxUpdateSize {
f.flush()
}
}
func (f *flusher) flush() {
if len(f.updates) > 0 {
update := imap.NewMessagesCreated(true, f.updates...)
f.updateCh.Enqueue(update)
f.updates = nil
f.curChunkSize = 0
f.pushedUpdates = append(f.pushedUpdates, update)
}
}
func (f *flusher) collectPushedUpdates() []imap.Update {
updates := f.pushedUpdates
f.pushedUpdates = nil
return updates
}

View File

@ -20,6 +20,7 @@ package user
import (
"fmt"
"reflect"
"runtime"
"strings"
"github.com/ProtonMail/go-proton-api"
@ -91,3 +92,7 @@ func sortSlice[Item any](items []Item, less func(Item, Item) bool) []Item {
return sorted
}
func newProtonAPIScheduler() proton.Scheduler {
return proton.NewParallelScheduler(runtime.NumCPU() / 2)
}
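The scheduler above is what the new GetFullMessage call sites in this commit pass alongside the default attachment allocator. A minimal (hypothetical) wrapper showing that usage:

func exampleFetchFull(ctx context.Context, client *proton.Client, id string) (proton.FullMessage, error) {
	return client.GetFullMessage(ctx, id, newProtonAPIScheduler(), proton.NewDefaultAttachmentAllocator())
}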

View File

@ -88,13 +88,12 @@ type User struct {
pollAPIEventsCh chan chan struct{}
goPollAPIEvents func(wait bool)
syncWorkers int
showAllMail uint32
maxSyncMemory uint64
}
// New returns a new user.
//
// nolint:funlen
func New(
ctx context.Context,
encVault *vault.User,
@ -102,9 +101,9 @@ func New(
reporter reporter.Reporter,
apiUser proton.User,
crashHandler async.PanicHandler,
syncWorkers int,
showAllMail bool,
) (*User, error) { //nolint:funlen
maxSyncMemory uint64,
) (*User, error) {
logrus.WithField("userID", apiUser.ID).Info("Creating new user")
// Get the user's API addresses.
@ -146,8 +145,9 @@ func New(
tasks: async.NewGroup(context.Background(), crashHandler),
pollAPIEventsCh: make(chan chan struct{}),
syncWorkers: syncWorkers,
showAllMail: b32(showAllMail),
maxSyncMemory: maxSyncMemory,
}
// Initialize the user's update channels for its current address mode.
@ -193,7 +193,15 @@ func New(
// Sync the user.
user.syncAbort.Do(ctx, func(ctx context.Context) {
if user.vault.SyncStatus().IsComplete() {
user.log.Info("Sync already complete, skipping")
user.log.Info("Sync already complete, only system label will be updated")
if err := user.syncSystemLabels(ctx); err != nil {
user.log.WithError(err).Error("Failed to update system labels")
return
}
user.log.Info("System label update complete, starting API event stream")
return
}
@ -257,11 +265,13 @@ func (user *User) Match(query string) bool {
}, user.apiUserLock, user.apiAddrsLock)
}
// Emails returns all the user's email addresses.
// Emails returns all the user's active email addresses.
// It returns them in sorted order; the user's primary address is first.
func (user *User) Emails() []string {
return safe.RLockRet(func() []string {
addresses := maps.Values(user.apiAddrs)
addresses := xslices.Filter(maps.Values(user.apiAddrs), func(addr proton.Address) bool {
return addr.Status == proton.AddressStatusEnabled
})
slices.SortFunc(addresses, func(a, b proton.Address) bool {
return a.Order < b.Order
@ -432,8 +442,6 @@ func (user *User) NewIMAPConnectors() (map[string]connector.Connector, error) {
}
// SendMail sends an email from the given address to the given recipients.
//
// nolint:funlen
func (user *User) SendMail(authID string, from string, to []string, r io.Reader) error {
if user.vault.SyncStatus().IsComplete() {
defer user.goPollAPIEvents(true)
@ -636,13 +644,11 @@ func (user *User) startEvents(ctx context.Context) {
}
// doEventPoll is called whenever API events should be polled.
//
//nolint:funlen
func (user *User) doEventPoll(ctx context.Context) error {
user.eventLock.Lock()
defer user.eventLock.Unlock()
event, err := user.client.GetEvent(ctx, user.vault.EventID())
event, _, err := user.client.GetEvent(ctx, user.vault.EventID())
if err != nil {
return fmt.Errorf("failed to get event (caused by %T): %w", internal.ErrCause(err), err)
}

View File

@ -119,14 +119,14 @@ func withUser(tb testing.TB, ctx context.Context, _ *server.Server, m *proton.Ma
saltedKeyPass, err := salts.SaltForKey([]byte(password), apiUser.Keys.Primary().ID)
require.NoError(tb, err)
vault, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"))
v, corrupt, err := vault.New(tb.TempDir(), tb.TempDir(), []byte("my secret key"))
require.NoError(tb, err)
require.False(tb, corrupt)
vaultUser, err := vault.AddUser(apiUser.ID, username, username+"@pm.me", apiAuth.UID, apiAuth.RefreshToken, saltedKeyPass)
vaultUser, err := v.AddUser(apiUser.ID, username, username+"@pm.me", apiAuth.UID, apiAuth.RefreshToken, saltedKeyPass)
require.NoError(tb, err)
user, err := New(ctx, vaultUser, client, nil, apiUser, nil, vault.SyncWorkers(), true)
user, err := New(ctx, vaultUser, client, nil, apiUser, nil, true, vault.DefaultMaxSyncMemory)
require.NoError(tb, err)
defer user.Close()