diff --git a/internal/user/sync.go b/internal/user/sync.go index ccf7b06b..af432080 100644 --- a/internal/user/sync.go +++ b/internal/user/sync.go @@ -415,6 +415,9 @@ func syncMessages( logrus.Debugf("sync downloader exit") }() + attachmentDownloader := newAttachmentDownloader(ctx, client, maxParallelDownloads) + defer attachmentDownloader.close() + for request := range downloadCh { logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize)) if request.err != nil { @@ -435,25 +438,13 @@ func syncMessages( return proton.FullMessage{}, err } - var attachmentSize int64 - for _, a := range msg.Attachments { - attachmentSize += a.Size - } - - // allocate attachment data. - result.AttData = make([][]byte, len(msg.Attachments)) - - for i, a := range msg.Attachments { - var buffer bytes.Buffer - buffer.Grow(int(a.Size)) - if err := client.GetAttachmentInto(ctx, a.ID, &buffer); err != nil { - return proton.FullMessage{}, err - } - - result.AttData[i] = buffer.Bytes() + attachments, err := attachmentDownloader.getAttachments(ctx, msg.Attachments) + if err != nil { + return proton.FullMessage{}, err } result.Message = msg + result.AttData = attachments return result, nil }) @@ -738,3 +729,86 @@ func wantLabels(apiLabels map[string]proton.Label, labelIDs []string) []string { return wantLabel(apiLabels[labelID]) }) } + +type attachmentResult struct { + attachment []byte + err error +} + +type attachmentJob struct { + id string + size int64 + result chan attachmentResult +} + +type attachmentDownloader struct { + workerCh chan attachmentJob + cancel context.CancelFunc +} + +func attachmentWorker(ctx context.Context, client *proton.Client, work <-chan attachmentJob) { + for { + select { + case <-ctx.Done(): + return + case job, ok := <-work: + if !ok { + return + } + var b bytes.Buffer + b.Grow(int(job.size)) + err := client.GetAttachmentInto(ctx, job.id, &b) + select { + case <-ctx.Done(): + close(job.result) + return + case job.result <- attachmentResult{attachment: b.Bytes(), err: err}: + close(job.result) + } + } + } +} + +func newAttachmentDownloader(ctx context.Context, client *proton.Client, workerCount int) *attachmentDownloader { + workerCh := make(chan attachmentJob, (workerCount+2)*workerCount) + ctx, cancel := context.WithCancel(ctx) + for i := 0; i < workerCount; i++ { + workerCh = make(chan attachmentJob) + logging.GoAnnotated(ctx, func(ctx context.Context) { attachmentWorker(ctx, client, workerCh) }, logging.Labels{ + "sync": fmt.Sprintf("att-downloader %v", i), + }) + } + + return &attachmentDownloader{ + workerCh: workerCh, + cancel: cancel, + } +} + +func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) { + resultChs := make([]chan attachmentResult, len(attachments)) + for i, id := range attachments { + resultChs[i] = make(chan attachmentResult, 1) + a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size} + } + + result := make([][]byte, len(attachments)) + var err error + for i := 0; i < len(attachments); i++ { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case r := <-resultChs[i]: + if r.err != nil { + err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err) + } + result[i] = r.attachment + } + } + + return result, err +} + +func (a *attachmentDownloader) close() { + a.cancel() +}