mirror of
https://github.com/ProtonMail/proton-bridge.git
synced 2025-12-10 04:36:43 +00:00
391 lines
9.5 KiB
Go
391 lines
9.5 KiB
Go
// Copyright (c) 2023 Proton AG
|
|
//
|
|
// This file is part of Proton Mail Bridge.
|
|
//
|
|
// Proton Mail Bridge is free software: you can redistribute it and/or modify
|
|
// it under the terms of the GNU General Public License as published by
|
|
// the Free Software Foundation, either version 3 of the License, or
|
|
// (at your option) any later version.
|
|
//
|
|
// Proton Mail Bridge is distributed in the hope that it will be useful,
|
|
// but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
// GNU General Public License for more details.
|
|
//
|
|
// You should have received a copy of the GNU General Public License
|
|
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
|
|
|
|
package user
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"io"
|
|
"time"
|
|
|
|
"github.com/ProtonMail/gluon/async"
|
|
"github.com/ProtonMail/gluon/logging"
|
|
"github.com/ProtonMail/go-proton-api"
|
|
"github.com/bradenaw/juniper/parallel"
|
|
"github.com/bradenaw/juniper/xslices"
|
|
"github.com/sirupsen/logrus"
|
|
)
|
|
|
|
type downloadRequest struct {
|
|
ids []string
|
|
expectedSize uint64
|
|
err error
|
|
}
|
|
|
|
type downloadedMessageBatch struct {
|
|
batch []proton.FullMessage
|
|
}
|
|
|
|
type MessageDownloader interface {
|
|
GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error
|
|
GetMessage(ctx context.Context, messageID string) (proton.Message, error)
|
|
}
|
|
|
|
type downloadState int
|
|
|
|
const (
|
|
downloadStateZero downloadState = iota
|
|
downloadStateHasMessage
|
|
downloadStateFinished
|
|
)
|
|
|
|
type downloadResult struct {
|
|
ID string
|
|
State downloadState
|
|
Message proton.FullMessage
|
|
err error
|
|
}
|
|
|
|
func startSyncDownloader(
|
|
ctx context.Context,
|
|
panicHandler async.PanicHandler,
|
|
downloader MessageDownloader,
|
|
cache *SyncDownloadCache,
|
|
downloadCh <-chan downloadRequest,
|
|
syncLimits syncLimits,
|
|
) (<-chan downloadedMessageBatch, <-chan error) {
|
|
buildCh := make(chan downloadedMessageBatch)
|
|
errorCh := make(chan error, syncLimits.MaxParallelDownloads*4)
|
|
|
|
// Goroutine in charge of downloading and building messages in maxBatchSize batches.
|
|
async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) {
|
|
defer close(buildCh)
|
|
defer close(errorCh)
|
|
defer func() {
|
|
logrus.Debugf("sync downloader exit")
|
|
}()
|
|
|
|
attachmentDownloader := newAttachmentDownloader(ctx, panicHandler, downloader, cache, syncLimits.MaxParallelDownloads)
|
|
defer attachmentDownloader.close()
|
|
|
|
for request := range downloadCh {
|
|
logrus.Debugf("Download request: %v MB:%v", len(request.ids), toMB(request.expectedSize))
|
|
if request.err != nil {
|
|
errorCh <- request.err
|
|
return
|
|
}
|
|
|
|
result, err := downloadMessagesParallel(ctx, panicHandler, request, downloader, attachmentDownloader, cache, syncLimits.MaxParallelDownloads)
|
|
if err != nil {
|
|
errorCh <- err
|
|
return
|
|
}
|
|
|
|
if ctx.Err() != nil {
|
|
errorCh <- ctx.Err()
|
|
return
|
|
}
|
|
|
|
batch, err := downloadMessagesSequential(ctx, result, downloader, cache, &expCooldown{})
|
|
if err != nil {
|
|
errorCh <- err
|
|
return
|
|
}
|
|
|
|
select {
|
|
case buildCh <- downloadedMessageBatch{
|
|
batch: batch,
|
|
}:
|
|
|
|
case <-ctx.Done():
|
|
return
|
|
}
|
|
}
|
|
}, logging.Labels{"sync-stage": "download"})
|
|
|
|
return buildCh, errorCh
|
|
}
|
|
|
|
type attachmentResult struct {
|
|
attachment []byte
|
|
err error
|
|
}
|
|
|
|
type attachmentJob struct {
|
|
id string
|
|
size int64
|
|
result chan attachmentResult
|
|
}
|
|
|
|
type attachmentDownloader struct {
|
|
workerCh chan attachmentJob
|
|
cancel context.CancelFunc
|
|
}
|
|
|
|
func attachmentWorker(ctx context.Context, downloader MessageDownloader, cache *SyncDownloadCache, work <-chan attachmentJob) {
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case job, ok := <-work:
|
|
if !ok {
|
|
return
|
|
}
|
|
|
|
var result attachmentResult
|
|
if data, ok := cache.GetAttachment(job.id); ok {
|
|
result.attachment = data
|
|
result.err = nil
|
|
} else {
|
|
var b bytes.Buffer
|
|
b.Grow(int(job.size))
|
|
err := downloader.GetAttachmentInto(ctx, job.id, &b)
|
|
result.attachment = b.Bytes()
|
|
result.err = err
|
|
if err == nil {
|
|
cache.StoreAttachment(job.id, result.attachment)
|
|
}
|
|
}
|
|
|
|
select {
|
|
case <-ctx.Done():
|
|
close(job.result)
|
|
return
|
|
case job.result <- result:
|
|
close(job.result)
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func newAttachmentDownloader(
|
|
ctx context.Context,
|
|
panicHandler async.PanicHandler,
|
|
downloader MessageDownloader,
|
|
cache *SyncDownloadCache,
|
|
workerCount int,
|
|
) *attachmentDownloader {
|
|
workerCh := make(chan attachmentJob, (workerCount+2)*workerCount)
|
|
ctx, cancel := context.WithCancel(ctx)
|
|
for i := 0; i < workerCount; i++ {
|
|
workerCh = make(chan attachmentJob)
|
|
async.GoAnnotated(ctx, panicHandler, func(ctx context.Context) { attachmentWorker(ctx, downloader, cache, workerCh) }, logging.Labels{
|
|
"sync": fmt.Sprintf("att-downloader %v", i),
|
|
})
|
|
}
|
|
|
|
return &attachmentDownloader{
|
|
workerCh: workerCh,
|
|
cancel: cancel,
|
|
}
|
|
}
|
|
|
|
func (a *attachmentDownloader) getAttachments(ctx context.Context, attachments []proton.Attachment) ([][]byte, error) {
|
|
resultChs := make([]chan attachmentResult, len(attachments))
|
|
for i, id := range attachments {
|
|
resultChs[i] = make(chan attachmentResult, 1)
|
|
select {
|
|
case a.workerCh <- attachmentJob{id: id.ID, result: resultChs[i], size: id.Size}:
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
}
|
|
}
|
|
|
|
result := make([][]byte, len(attachments))
|
|
var err error
|
|
for i := 0; i < len(attachments); i++ {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case r := <-resultChs[i]:
|
|
if r.err != nil {
|
|
err = fmt.Errorf("failed to get attachment %v: %w", attachments[i], r.err)
|
|
}
|
|
result[i] = r.attachment
|
|
}
|
|
}
|
|
|
|
return result, err
|
|
}
|
|
|
|
func (a *attachmentDownloader) close() {
|
|
a.cancel()
|
|
}
|
|
|
|
func downloadMessagesParallel(
|
|
ctx context.Context,
|
|
panicHandler async.PanicHandler,
|
|
request downloadRequest,
|
|
downloader MessageDownloader,
|
|
attachmentDownloader *attachmentDownloader,
|
|
cache *SyncDownloadCache,
|
|
parallelDownloads int,
|
|
) ([]downloadResult, error) {
|
|
// 1st attempt download everything in parallel
|
|
return parallel.MapContext(ctx, parallelDownloads, request.ids, func(ctx context.Context, id string) (downloadResult, error) {
|
|
defer async.HandlePanic(panicHandler)
|
|
|
|
result := downloadResult{ID: id}
|
|
|
|
v, ok := cache.GetMessage(id)
|
|
if !ok {
|
|
msg, err := downloader.GetMessage(ctx, id)
|
|
if err != nil {
|
|
logrus.WithError(err).WithField("msgID", msg.ID).Error("Failed to download message")
|
|
result.err = err
|
|
return result, nil
|
|
}
|
|
|
|
cache.StoreMessage(msg)
|
|
result.Message.Message = msg
|
|
} else {
|
|
result.Message.Message = v
|
|
}
|
|
|
|
result.State = downloadStateHasMessage
|
|
|
|
attachments, err := attachmentDownloader.getAttachments(ctx, result.Message.Attachments)
|
|
result.Message.AttData = attachments
|
|
|
|
if err != nil {
|
|
logrus.WithError(err).WithField("msgID", id).Error("Failed to download message attachments")
|
|
result.err = err
|
|
return result, nil
|
|
}
|
|
|
|
result.State = downloadStateFinished
|
|
|
|
return result, nil
|
|
})
|
|
}
|
|
|
|
func downloadMessagesSequential(
|
|
ctx context.Context,
|
|
state []downloadResult,
|
|
downloader MessageDownloader,
|
|
cache *SyncDownloadCache,
|
|
cooldown cooldownProvider,
|
|
) ([]proton.FullMessage, error) {
|
|
logrus.Debug("Entering download stage 2")
|
|
var retryList []int
|
|
var shouldWaitBeforeRetry bool
|
|
|
|
for {
|
|
if shouldWaitBeforeRetry {
|
|
time.Sleep(cooldown.GetNextWaitTime())
|
|
}
|
|
|
|
retryList = nil
|
|
shouldWaitBeforeRetry = false
|
|
|
|
for index, s := range state {
|
|
if s.State == downloadStateFinished {
|
|
continue
|
|
}
|
|
|
|
if s.err != nil {
|
|
if is429Error(s.err) {
|
|
logrus.WithField("msg-id", s.ID).Debug("Message download failed due to 429, retrying")
|
|
retryList = append(retryList, index)
|
|
continue
|
|
}
|
|
return nil, s.err
|
|
}
|
|
}
|
|
|
|
if len(retryList) == 0 {
|
|
break
|
|
}
|
|
|
|
for _, i := range retryList {
|
|
st := &state[i]
|
|
if st.State == downloadStateZero {
|
|
message, err := downloader.GetMessage(ctx, st.ID)
|
|
if err != nil {
|
|
logrus.WithField("msg-id", st.ID).WithError(err).Error("failed to download message (429)")
|
|
if is429Error(err) {
|
|
st.err = err
|
|
shouldWaitBeforeRetry = true
|
|
continue
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
cache.StoreMessage(message)
|
|
st.Message.Message = message
|
|
st.State = downloadStateHasMessage
|
|
}
|
|
|
|
if st.Message.AttData == nil && st.Message.NumAttachments != 0 {
|
|
st.Message.AttData = make([][]byte, st.Message.NumAttachments)
|
|
}
|
|
|
|
hasAllAttachments := true
|
|
for i := 0; i < st.Message.NumAttachments; i++ {
|
|
if st.Message.AttData[i] == nil {
|
|
buffer := bytes.Buffer{}
|
|
if err := downloader.GetAttachmentInto(ctx, st.Message.Attachments[i].ID, &buffer); err != nil {
|
|
logrus.WithField("msg-id", st.ID).WithError(err).Errorf("failed to download attachment %v/%v (429)", i+1, len(st.Message.Attachments))
|
|
if is429Error(err) {
|
|
st.err = err
|
|
shouldWaitBeforeRetry = true
|
|
hasAllAttachments = false
|
|
continue
|
|
}
|
|
|
|
return nil, err
|
|
}
|
|
|
|
st.Message.AttData[i] = buffer.Bytes()
|
|
cache.StoreAttachment(st.Message.Attachments[i].ID, st.Message.AttData[i])
|
|
}
|
|
}
|
|
|
|
if hasAllAttachments {
|
|
st.State = downloadStateFinished
|
|
}
|
|
}
|
|
}
|
|
|
|
logrus.Debug("All message downloaded successfully")
|
|
return xslices.Map(state, func(s downloadResult) proton.FullMessage {
|
|
return s.Message
|
|
}), nil
|
|
}
|
|
|
|
func is429Error(err error) bool {
|
|
var apiErr *proton.APIError
|
|
if errors.As(err, &apiErr) {
|
|
return apiErr.Status == 429
|
|
}
|
|
|
|
return false
|
|
}
|
|
|
|
func is429Or5XXError(err error) bool {
|
|
var apiErr *proton.APIError
|
|
if errors.As(err, &apiErr) {
|
|
return apiErr.Status == 429 || apiErr.Status >= 500
|
|
}
|
|
|
|
return false
|
|
}
|