mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-18 16:17:03 +00:00
GODT-2224: Refactor bridge sync to use less memory
Updates go-proton-api and Gluon to includes memory reduction changes and modify the sync process to take into account how much memory is used during the sync stage. The sync process now has an extra stage which first download the message metada to ensure that we only download up to `syncMaxDownloadRequesMem` messages or 250 messages total. This allows for scaling the download request automatically to accommodate many small or few very large messages. The IDs are then sent to a download go-routine which downloads the message and its attachments. The result is then forwarded to another go-routine which builds the actual message. This stage tries to ensure that we don't use more than `syncMaxMessageBuildingMem` to build these messages. Finally the result is sent to a last go-routine which applies the changes to Gluon and waits for them to be completed. The new process is currently limited to 2GB. Dynamic scaling will be implemented in a follow up. For systems with less than 2GB of memory we limit the values to a set of values that is known to work.
This commit is contained in:
@ -18,6 +18,7 @@
|
||||
package user
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"runtime"
|
||||
@ -25,6 +26,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/ProtonMail/gluon/imap"
|
||||
"github.com/ProtonMail/gluon/logging"
|
||||
"github.com/ProtonMail/gluon/queue"
|
||||
"github.com/ProtonMail/gluon/reporter"
|
||||
"github.com/ProtonMail/go-proton-api"
|
||||
@ -35,16 +37,12 @@ import (
|
||||
"github.com/bradenaw/juniper/parallel"
|
||||
"github.com/bradenaw/juniper/xslices"
|
||||
"github.com/google/uuid"
|
||||
"github.com/pbnjay/memory"
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.org/x/exp/slices"
|
||||
)
|
||||
|
||||
const (
|
||||
maxUpdateSize = 1 << 27 // 128 MiB
|
||||
maxBatchSize = 1 << 8 // 256
|
||||
)
|
||||
|
||||
// doSync begins syncing the users data.
|
||||
// It first ensures the latest event ID is known; if not, it fetches it.
|
||||
// It sends a SyncStarted event and then either SyncFinished or SyncFailed
|
||||
@ -143,7 +141,6 @@ func (user *User) sync(ctx context.Context) error {
|
||||
addrKRs,
|
||||
user.updateCh,
|
||||
user.eventCh,
|
||||
user.syncWorkers,
|
||||
); err != nil {
|
||||
return fmt.Errorf("failed to sync messages: %w", err)
|
||||
}
|
||||
@ -212,7 +209,15 @@ func syncLabels(ctx context.Context, apiLabels map[string]proton.Label, updateCh
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint:funlen
|
||||
const Kilobyte = uint64(1024)
|
||||
const Megabyte = 1024 * Kilobyte
|
||||
const Gigabyte = 1024 * Megabyte
|
||||
|
||||
func toMB(v uint64) float64 {
|
||||
return float64(v) / float64(Megabyte)
|
||||
}
|
||||
|
||||
// nolint:funlen,gocyclo
|
||||
func syncMessages(
|
||||
ctx context.Context,
|
||||
userID string,
|
||||
@ -224,7 +229,6 @@ func syncMessages(
|
||||
addrKRs map[string]*crypto.KeyRing,
|
||||
updateCh map[string]*queue.QueuedChannel[imap.Update],
|
||||
eventCh *queue.QueuedChannel[events.Event],
|
||||
syncWorkers int,
|
||||
) error {
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
@ -235,78 +239,319 @@ func syncMessages(
|
||||
|
||||
logrus.WithFields(logrus.Fields{
|
||||
"messages": len(messageIDs),
|
||||
"workers": syncWorkers,
|
||||
"numCPU": runtime.NumCPU(),
|
||||
}).Info("Starting message sync")
|
||||
|
||||
// Create the flushers, one per update channel.
|
||||
flushers := make(map[string]*flusher, len(updateCh))
|
||||
|
||||
for addrID, updateCh := range updateCh {
|
||||
flushers[addrID] = newFlusher(updateCh, maxUpdateSize)
|
||||
}
|
||||
|
||||
// Create a reporter to report sync progress updates.
|
||||
syncReporter := newSyncReporter(userID, eventCh, len(messageIDs), time.Second)
|
||||
defer syncReporter.done()
|
||||
|
||||
type flushUpdate struct {
|
||||
messageID string
|
||||
pushedUpdates []imap.Update
|
||||
batchLen int
|
||||
// Expected mem usage for this whole process should be the sum of MaxMessageBuildingMem and MaxDownloadRequestMem
|
||||
// times x due to pipeline and all additional memory used by network requests and compression+io.
|
||||
|
||||
// There's no point in using more than 128MB of download data per stage, after that we reach a point of diminishing
|
||||
// returns as we can't keep the pipeline fed fast enough.
|
||||
const MaxDownloadRequestMem = 128 * Megabyte
|
||||
|
||||
// Any lower than this and we may fail to download messages.
|
||||
const MinDownloadRequestMem = 40 * Megabyte
|
||||
|
||||
// This value can be increased to your hearts content. The more system memory the user has, the more messages
|
||||
// we can build in parallel.
|
||||
const MaxMessageBuildingMem = 128 * Megabyte
|
||||
const MinMessageBuildingMem = 64 * Megabyte
|
||||
|
||||
// Maximum recommend value for parallel downloads by the API team.
|
||||
const maxParallelDownloads = 20
|
||||
|
||||
totalMemory := memory.TotalMemory()
|
||||
logrus.Debugf("Total System Memory: %v", toMB(totalMemory))
|
||||
|
||||
syncMaxDownloadRequestMem := MaxDownloadRequestMem
|
||||
syncMaxMessageBuildingMem := MaxMessageBuildingMem
|
||||
|
||||
// If less than 2GB available try and limit max memory to 512 MB
|
||||
if totalMemory < 2*Gigabyte {
|
||||
if totalMemory < 800*Megabyte {
|
||||
logrus.Warnf("System has less than 800MB of memory, you may experience issues sycing large mailboxes")
|
||||
}
|
||||
syncMaxDownloadRequestMem = MinDownloadRequestMem
|
||||
syncMaxMessageBuildingMem = MinMessageBuildingMem
|
||||
} else {
|
||||
// Increasing the max download capacity has very little effect on sync speed. We could increase the download
|
||||
// memory but the user would see less sync notifications. A smaller value here leads to more frequent
|
||||
// updates. Additionally, most of ot sync time is spent in the message building.
|
||||
syncMaxDownloadRequestMem = MaxDownloadRequestMem
|
||||
// Currently limited so that if a user has multiple accounts active it also doesn't cause excessive memory usage.
|
||||
syncMaxMessageBuildingMem = MaxMessageBuildingMem
|
||||
}
|
||||
|
||||
logrus.Debugf("Max memory usage for sync Download=%vMB Building=%vMB Predicted Max Total=%vMB",
|
||||
toMB(syncMaxDownloadRequestMem),
|
||||
toMB(syncMaxMessageBuildingMem),
|
||||
toMB((syncMaxMessageBuildingMem*4)+(syncMaxDownloadRequestMem*4)),
|
||||
)
|
||||
|
||||
type flushUpdate struct {
|
||||
messageID string
|
||||
err error
|
||||
batchLen int
|
||||
}
|
||||
|
||||
type downloadRequest struct {
|
||||
ids []string
|
||||
expectedSize uint64
|
||||
err error
|
||||
}
|
||||
|
||||
type downloadedMessageBatch struct {
|
||||
batch []proton.FullMessage
|
||||
}
|
||||
|
||||
type builtMessageBatch struct {
|
||||
batch []*buildRes
|
||||
}
|
||||
|
||||
downloadCh := make(chan downloadRequest)
|
||||
|
||||
buildCh := make(chan downloadedMessageBatch)
|
||||
|
||||
// The higher this value, the longer we can continue our download iteration before being blocked on channel writes
|
||||
// to the update flushing goroutine.
|
||||
flushCh := make(chan []*buildRes, 2)
|
||||
flushCh := make(chan builtMessageBatch)
|
||||
|
||||
// Allow up to 4 batched wait requests.
|
||||
flushUpdateCh := make(chan flushUpdate, 4)
|
||||
flushUpdateCh := make(chan flushUpdate)
|
||||
|
||||
errorCh := make(chan error, syncWorkers)
|
||||
errorCh := make(chan error, maxParallelDownloads+2)
|
||||
|
||||
// Go routine in charge of downloading message metadata
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
defer close(downloadCh)
|
||||
const MetadataDataPageSize = 150
|
||||
|
||||
var downloadReq downloadRequest
|
||||
downloadReq.ids = make([]string, 0, MetadataDataPageSize)
|
||||
|
||||
metadataChunks := xslices.Chunk(messageIDs, MetadataDataPageSize)
|
||||
for i, metadataChunk := range metadataChunks {
|
||||
logrus.Debugf("Metadata Request (%v of %v), previous: %v", i, len(metadataChunks), len(downloadReq.ids))
|
||||
metadata, err := client.GetMessageMetadataPage(ctx, 0, len(metadataChunk), proton.MessageFilter{ID: metadataChunk})
|
||||
if err != nil {
|
||||
downloadReq.err = err
|
||||
downloadCh <- downloadReq
|
||||
return
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
downloadReq.err = err
|
||||
downloadCh <- downloadReq
|
||||
return
|
||||
}
|
||||
|
||||
// Build look up table so that messages are processed in the same order.
|
||||
metadataMap := make(map[string]int, len(metadata))
|
||||
for i, v := range metadata {
|
||||
metadataMap[v.ID] = i
|
||||
}
|
||||
|
||||
for i, id := range metadataChunk {
|
||||
m := &metadata[metadataMap[id]]
|
||||
nextSize := downloadReq.expectedSize + uint64(m.Size)
|
||||
if nextSize >= syncMaxDownloadRequestMem || len(downloadReq.ids) >= 256 {
|
||||
logrus.Debugf("Download Request Sent at %v of %v", i, len(metadata))
|
||||
select {
|
||||
case downloadCh <- downloadReq:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
downloadReq.expectedSize = 0
|
||||
downloadReq.ids = make([]string, 0, MetadataDataPageSize)
|
||||
nextSize = uint64(m.Size)
|
||||
}
|
||||
downloadReq.ids = append(downloadReq.ids, id)
|
||||
downloadReq.expectedSize = nextSize
|
||||
}
|
||||
}
|
||||
|
||||
if len(downloadReq.ids) != 0 {
|
||||
logrus.Debugf("Sending remaining download request")
|
||||
select {
|
||||
case downloadCh <- downloadReq:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}, logging.Labels{"sync-stage": "meta-data"})
|
||||
|
||||
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
||||
go func() {
|
||||
defer close(flushCh)
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
defer close(buildCh)
|
||||
defer close(errorCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync downloader exit")
|
||||
}()
|
||||
|
||||
for request := range downloadCh {
|
||||
logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize))
|
||||
if request.err != nil {
|
||||
errorCh <- request.err
|
||||
return
|
||||
}
|
||||
|
||||
for _, batch := range xslices.Chunk(messageIDs, maxBatchSize) {
|
||||
if ctx.Err() != nil {
|
||||
errorCh <- ctx.Err()
|
||||
return
|
||||
}
|
||||
|
||||
result, err := parallel.MapContext(ctx, syncWorkers, batch, func(ctx context.Context, id string) (*buildRes, error) {
|
||||
msg, err := client.GetFullMessage(ctx, id)
|
||||
result, err := parallel.MapContext(ctx, maxParallelDownloads, request.ids, func(ctx context.Context, id string) (proton.FullMessage, error) {
|
||||
var result proton.FullMessage
|
||||
|
||||
msg, err := client.GetMessage(ctx, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return proton.FullMessage{}, err
|
||||
}
|
||||
|
||||
if ctx.Err() != nil {
|
||||
return nil, ctx.Err()
|
||||
var attachmentSize int64
|
||||
for _, a := range msg.Attachments {
|
||||
attachmentSize += a.Size
|
||||
}
|
||||
|
||||
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID]), nil
|
||||
// allocate attachment data.
|
||||
result.AttData = make([][]byte, len(msg.Attachments))
|
||||
|
||||
for i, a := range msg.Attachments {
|
||||
var buffer bytes.Buffer
|
||||
buffer.Grow(int(a.Size))
|
||||
if err := client.GetAttachmentInto(ctx, a.ID, &buffer); err != nil {
|
||||
return proton.FullMessage{}, err
|
||||
}
|
||||
|
||||
result.AttData[i] = buffer.Bytes()
|
||||
}
|
||||
|
||||
result.Message = msg
|
||||
|
||||
return result, nil
|
||||
})
|
||||
if err != nil {
|
||||
errorCh <- err
|
||||
return
|
||||
}
|
||||
|
||||
select {
|
||||
case buildCh <- downloadedMessageBatch{
|
||||
batch: result,
|
||||
}:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}, logging.Labels{"sync-stage": "download"})
|
||||
|
||||
// Goroutine which builds messages after they have been downloaded
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
defer close(flushCh)
|
||||
defer func() {
|
||||
logrus.Debugf("sync builder exit")
|
||||
}()
|
||||
|
||||
maxMessagesInParallel := runtime.NumCPU()
|
||||
|
||||
for buildBatch := range buildCh {
|
||||
if ctx.Err() != nil {
|
||||
errorCh <- ctx.Err()
|
||||
return
|
||||
}
|
||||
|
||||
flushCh <- result
|
||||
var expectedMemUsage uint64
|
||||
var chunks [][]proton.FullMessage
|
||||
|
||||
{
|
||||
var lastIndex int
|
||||
var index int
|
||||
for _, v := range buildBatch.batch {
|
||||
var dataSize uint64
|
||||
for _, a := range v.Attachments {
|
||||
dataSize += uint64(a.Size)
|
||||
}
|
||||
|
||||
// 2x increase for attachment due to extra memory needed for decrypting and writing
|
||||
// in memory buffer.
|
||||
dataSize *= 2
|
||||
dataSize += uint64(len(v.Body))
|
||||
|
||||
nextMemSize := expectedMemUsage + dataSize
|
||||
if nextMemSize >= syncMaxMessageBuildingMem {
|
||||
chunks = append(chunks, buildBatch.batch[lastIndex:index])
|
||||
lastIndex = index
|
||||
expectedMemUsage = dataSize
|
||||
} else {
|
||||
expectedMemUsage = nextMemSize
|
||||
}
|
||||
|
||||
index++
|
||||
}
|
||||
|
||||
if index < len(buildBatch.batch) {
|
||||
chunks = append(chunks, buildBatch.batch[index:])
|
||||
} else if index == len(buildBatch.batch) && len(chunks) == 0 {
|
||||
chunks = [][]proton.FullMessage{buildBatch.batch}
|
||||
}
|
||||
}
|
||||
|
||||
for index, chunk := range chunks {
|
||||
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (*buildRes, error) {
|
||||
return buildRFC822(apiLabels, msg, addrKRs[msg.AddressID], new(bytes.Buffer)), nil
|
||||
})
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
logrus.Debugf("Build request: %v of %v", index, len(chunks))
|
||||
|
||||
select {
|
||||
case flushCh <- builtMessageBatch{result}:
|
||||
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}, logging.Labels{"sync-stage": "builder"})
|
||||
|
||||
// Goroutine which converts the messages into updates and builds a waitable structure for progress tracking.
|
||||
go func() {
|
||||
logging.GoAnnotated(ctx, func(ctx context.Context) {
|
||||
defer close(flushUpdateCh)
|
||||
for batch := range flushCh {
|
||||
for _, res := range batch {
|
||||
defer func() {
|
||||
logrus.Debugf("sync flush exit")
|
||||
}()
|
||||
|
||||
type updateTargetInfo struct {
|
||||
queueIndex int
|
||||
ch *queue.QueuedChannel[imap.Update]
|
||||
}
|
||||
|
||||
pendingUpdates := make([][]*imap.MessageCreated, len(updateCh))
|
||||
addressToIndex := make(map[string]updateTargetInfo)
|
||||
|
||||
{
|
||||
i := 0
|
||||
for addrID, updateCh := range updateCh {
|
||||
addressToIndex[addrID] = updateTargetInfo{
|
||||
ch: updateCh,
|
||||
queueIndex: i,
|
||||
}
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
for downloadBatch := range flushCh {
|
||||
logrus.Debugf("Flush batch: %v", len(downloadBatch.batch))
|
||||
for _, res := range downloadBatch.batch {
|
||||
if res.err != nil {
|
||||
if err := vault.AddFailedMessageID(res.messageID); err != nil {
|
||||
logrus.WithError(err).Error("Failed to add failed message ID")
|
||||
@ -327,31 +572,38 @@ func syncMessages(
|
||||
}
|
||||
}
|
||||
|
||||
flushers[res.addressID].push(res.update)
|
||||
targetInfo := addressToIndex[res.addressID]
|
||||
pendingUpdates[targetInfo.queueIndex] = append(pendingUpdates[targetInfo.queueIndex], res.update)
|
||||
}
|
||||
|
||||
var pushedUpdates []imap.Update
|
||||
for _, flusher := range flushers {
|
||||
flusher.flush()
|
||||
pushedUpdates = append(pushedUpdates, flusher.collectPushedUpdates()...)
|
||||
for _, info := range addressToIndex {
|
||||
up := imap.NewMessagesCreated(true, pendingUpdates[info.queueIndex]...)
|
||||
info.ch.Enqueue(up)
|
||||
|
||||
err, ok := up.WaitContext(ctx)
|
||||
if ok && err != nil {
|
||||
flushUpdateCh <- flushUpdate{
|
||||
err: fmt.Errorf("failed to apply sync update to gluon %v: %w", up.String(), err),
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
pendingUpdates[info.queueIndex] = pendingUpdates[info.queueIndex][:0]
|
||||
}
|
||||
|
||||
flushUpdateCh <- flushUpdate{
|
||||
messageID: batch[0].messageID,
|
||||
pushedUpdates: pushedUpdates,
|
||||
batchLen: len(batch),
|
||||
select {
|
||||
case flushUpdateCh <- flushUpdate{
|
||||
messageID: downloadBatch.batch[0].messageID,
|
||||
err: nil,
|
||||
batchLen: len(downloadBatch.batch),
|
||||
}:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
}, logging.Labels{"sync-stage": "flush"})
|
||||
|
||||
for flushUpdate := range flushUpdateCh {
|
||||
for _, up := range flushUpdate.pushedUpdates {
|
||||
err, ok := up.WaitContext(ctx)
|
||||
if ok && err != nil {
|
||||
return fmt.Errorf("failed to apply sync update to gluon %v: %w", up.String(), err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := vault.SetLastMessageID(flushUpdate.messageID); err != nil {
|
||||
return fmt.Errorf("failed to set last synced message ID: %w", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user