// proton-bridge/internal/services/syncservice/job.go

// Copyright (c) 2024 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package syncservice
import (
"context"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/sirupsen/logrus"
)
// Job represents a unit of work that will travel down the sync pipeline. The job will be split up into child jobs
// for each batch. The parent job (this) will then wait until all the children have finished executing. Execution can
// terminate by either:
//   - Completing the pipeline successfully
//   - Context cancellation
//   - Errors
//
// On error or context cancellation, all child jobs are cancelled through the shared job context (ctx).
//
// Task accounting is delegated to jw: every child job created with newChildJob registers one task, and every
// completion path (onJobFinished, onError, childJob.checkCancelled, end) reports one finished task.
type Job struct {
	// ctx is derived from the context given to NewJob; cancel cancels it and is invoked by onError.
	ctx    context.Context
	cancel func()

	client         APIClient
	state          StateProvider
	userID         string
	labels         LabelMap
	messageBuilder MessageBuilder
	updateApplier  UpdateApplier
	syncReporter   Reporter
	log            *logrus.Entry

	// jw counts outstanding tasks and yields the first reported error once all of them are done.
	jw *jobWaiter

	panicHandler async.PanicHandler
	// downloadCache is pruned by childJob.onFinished using the IDs recorded by collectIDs.
	downloadCache *DownloadCache

	// NOTE(review): metadataFetched and totalMessageCount are neither set nor read in this file; presumably
	// they are maintained by the pipeline stages — confirm.
	metadataFetched   int64
	totalMessageCount int64
}
// NewJob creates a Job whose context is derived from ctx and immediately starts its jobWaiter.
//
// The waiter begins with one outstanding task, which end releases. The child tasks registered through
// newChildJob must also report back before waitAndClose returns.
func NewJob(ctx context.Context,
	client APIClient,
	userID string,
	labels LabelMap,
	messageBuilder MessageBuilder,
	updateApplier UpdateApplier,
	syncReporter Reporter,
	state StateProvider,
	panicHandler async.PanicHandler,
	cache *DownloadCache,
	log *logrus.Entry,
) *Job {
	jobCtx, cancelJob := context.WithCancel(ctx)

	waiter := newJobWaiter(log.WithField("sync-job", "waiter"), panicHandler)

	job := &Job{
		ctx:            jobCtx,
		cancel:         cancelJob,
		client:         client,
		userID:         userID,
		state:          state,
		labels:         labels,
		messageBuilder: messageBuilder,
		updateApplier:  updateApplier,
		syncReporter:   syncReporter,
		panicHandler:   panicHandler,
		downloadCache:  cache,
		log:            log,
		jw:             waiter,
	}

	job.jw.begin()

	return job
}
// close releases the resources held by a finished job. It cancels the job context created in NewJob and then
// closes the waiter's message channel.
//
// It must only be called once every task has reported back. Any onTaskCreated/onTaskFinished sent after close
// panics (send on a closed channel).
func (j *Job) close() {
	// Without this, the context derived in NewJob stays registered with its parent until the parent itself
	// is cancelled (see the context.WithCancel documentation). Calling it again is a no-op.
	j.cancel()
	j.jw.close()
}
// onError fails the job. It cancels the job context so every other child job can observe the cancellation, and
// then reports the failing task as finished with err.
func (j *Job) onError(err error) {
	defer func() {
		j.jw.onTaskFinished(err)
	}()

	j.cancel()
}
// onStageCompleted forwards progress for count messages to the sync reporter.
func (j *Job) onStageCompleted(ctx context.Context, count int64) {
	reporter := j.syncReporter
	reporter.OnProgress(ctx, count)
}
// onJobFinished persists lastMessageID/count through the state provider.
//
// On success it reports progress for count messages and marks the task as finished. If persisting fails, the
// error is logged and the whole job is failed via onError, which also marks the task as finished.
func (j *Job) onJobFinished(ctx context.Context, lastMessageID string, count int64) {
	err := j.state.SetLastMessageID(ctx, lastMessageID, count)
	if err == nil {
		defer j.jw.onTaskFinished(nil)
		j.syncReporter.OnProgress(ctx, count)

		return
	}

	j.log.WithError(err).Error("Failed to store last synced message id")
	// onError reports the task as finished itself, so no separate onTaskFinished here.
	j.onError(err)
}
// begin is expected to be called once the job enters the pipeline; it only logs that the job has started.
func (j *Job) begin() {
	logger := j.log
	logger.Info("Job started")
}
// end is expected to be called once the job has no further work left. It logs the event and reports one
// finished task, releasing the initial task the waiter starts counting with.
func (j *Job) end() {
	logger, waiter := j.log, j.jw

	logger.Info("Job finished")
	waiter.onTaskFinished(nil)
}
// waitAndClose blocks until the job has finished and then closes it.
//
// It returns the first error reported by any task, or ctx.Err() if ctx is done first. On cancellation the job
// context is cancelled too, so child jobs observe it. The call still waits for every outstanding task to report
// back before closing, because a report arriving after close would panic on the closed channel.
func (j *Job) waitAndClose(ctx context.Context) error {
	defer j.close()

	select {
	case <-ctx.Done():
		// ctx is not necessarily the parent of j.ctx. Cancel the job explicitly so children actually stop,
		// instead of leaving this call blocked on doneCh indefinitely. Idempotent if already cancelled.
		j.cancel()
		<-j.jw.doneCh

		return ctx.Err()
	case err := <-j.jw.doneCh:
		return err
	}
}
// newChildJob registers one more outstanding task with the job waiter. It returns a child job for a batch whose
// last message is messageID and which holds messageCount messages.
func (j *Job) newChildJob(messageID string, messageCount int64) childJob {
	j.log.Info("Creating new child job")
	j.jw.onTaskCreated()

	child := childJob{
		job:           j,
		lastMessageID: messageID,
		messageCount:  messageCount,
	}

	return child
}
// childJob represents a batch of work that goes down the pipeline. It keeps track of the ID of the last message in
// the batch and the number of messages in the batch.
type childJob struct {
	// job is the parent job whose waiter tracks this child's lifetime.
	job *Job
	// lastMessageID and messageCount are handed to the parent's onJobFinished when the child completes.
	// messageCount is also used for per-stage progress.
	lastMessageID string
	messageCount  int64
	// cachedMessageIDs and cachedAttachmentIDs are filled by collectIDs. onFinished removes them from the
	// parent's download cache.
	cachedMessageIDs    []string
	cachedAttachmentIDs []string
}
// onError logs err at info level and forwards it to the parent job. The parent cancels the whole job and reports
// this task as finished.
func (s *childJob) onError(err error) {
	parent := s.job

	parent.log.WithError(err).Info("Child job ran into error")
	parent.onError(err)
}
// chunkDivide splits s into one child job per chunk so each chunk can travel down the rest of the pipeline on its
// own.
//
// Every chunk except the last gets a freshly registered child job (see newChildJob). That job is identified by the
// ID of the chunk's last message and sized by the chunk length. The last chunk reuses s itself, so the task s
// already accounts for is not registered twice.
//
// Every returned job, including the single-chunk case, records the message and attachment IDs of its chunk.
// onFinished can then evict them from the download cache.
//
// NOTE(review): the reused job keeps s.lastMessageID and s.messageCount rather than values derived from the last
// chunk. Confirm that downstream progress accounting expects this.
//
// chunks must not be empty and every chunk must contain at least one message; otherwise this panics.
func (s *childJob) chunkDivide(chunks [][]proton.FullMessage) []childJob {
	numChunks := len(chunks)
	last := numChunks - 1

	result := make([]childJob, numChunks)
	for i, chunk := range chunks[:last] {
		result[i] = s.job.newChildJob(chunk[len(chunk)-1].ID, int64(len(chunk)))
		collectIDs(&result[i], chunk)
	}

	// The final chunk is carried by the original job. Recording its IDs here for the single-chunk case as well
	// ensures its cache entries are evicted when the batch finishes.
	result[last] = *s
	collectIDs(&result[last], chunks[last])

	return result
}
// collectIDs records on j the IDs of msgs and of all their attachments, replacing any IDs recorded earlier.
func collectIDs(j *childJob, msgs []proton.FullMessage) {
	attachmentIDs := make([]string, 0, len(msgs))
	messageIDs := make([]string, 0, len(msgs))

	for i := range msgs {
		messageIDs = append(messageIDs, msgs[i].ID)

		for _, att := range msgs[i].Attachments {
			attachmentIDs = append(attachmentIDs, att.ID)
		}
	}

	j.cachedAttachmentIDs = attachmentIDs
	j.cachedMessageIDs = messageIDs
}
// onFinished completes the child job. It first evicts the child's messages and attachments from the download
// cache. It then hands lastMessageID and messageCount to the parent job, which persists them, reports progress and
// marks this task as finished (or fails the job if persisting fails).
func (s *childJob) onFinished(ctx context.Context) {
	s.job.log.Info("Child job finished")

	// Evict before signalling completion. Once the last task is reported, the waiter releases waitAndClose and
	// the job is closed, so cleanup done afterwards could still be running when the caller moves on.
	cache := s.job.downloadCache
	cache.DeleteMessages(s.cachedMessageIDs...)
	cache.DeleteAttachments(s.cachedAttachmentIDs...)

	s.job.onJobFinished(ctx, s.lastMessageID, s.messageCount)
}
// onStageCompleted reports progress for this child's messageCount messages through the parent job.
func (s *childJob) onStageCompleted(ctx context.Context) {
	count := s.messageCount
	s.job.onStageCompleted(ctx, count)
}
// checkCancelled reports whether the parent job's context is done. When it is, the child's task is marked as
// finished with the context error. After a true result the caller should not report this child to the waiter
// again.
func (s *childJob) checkCancelled() bool {
	ctxErr := s.job.ctx.Err()
	if ctxErr == nil {
		return false
	}

	s.job.log.Info("Child job exit due to context cancelled")
	s.job.jw.onTaskFinished(ctxErr)

	return true
}
// getContext returns the parent job's context, which is cancelled when the job fails.
func (s *childJob) getContext() context.Context {
	parent := s.job
	return parent.ctx
}
// JobWaiterMessage identifies the kind of event sent to a jobWaiter.
type JobWaiterMessage int

const (
	// JobWaiterMessageCreated signals that a new task was registered; the waiter increments its count.
	JobWaiterMessageCreated JobWaiterMessage = iota
	// JobWaiterMessageFinished signals that a task completed, possibly with an error; the waiter decrements its
	// count.
	JobWaiterMessageFinished
)
// jobWaiterMessagePair is the value sent over jobWaiter.ch. It carries the event kind plus, for
// JobWaiterMessageFinished, the task's error (nil when the task succeeded).
type jobWaiterMessagePair struct {
	m   JobWaiterMessage
	err error
}
// jobWaiter is meant to be used to track ongoing sync batches. It counts outstanding tasks, starting at one. The
// count goes up on every JobWaiterMessageCreated and down on every JobWaiterMessageFinished. Once the count drops
// to zero (or ch is closed), the first recorded error (if any) is written to doneCh and then that channel is
// closed.
type jobWaiter struct {
	// ch carries task events to the goroutine started by begin; newJobWaiter creates it unbuffered.
	ch chan jobWaiterMessagePair
	// doneCh receives exactly one value (the first error, or nil) before being closed.
	doneCh chan error
	log    *logrus.Entry
	// panicHandler is passed to async.HandlePanic inside the waiter goroutine.
	panicHandler async.PanicHandler
}
// newJobWaiter creates a jobWaiter that is not running yet; call begin to start it.
//
// ch is unbuffered, so every task report blocks until the waiter goroutine receives it. doneCh carries a single
// value before it is closed. Its buffer lets the goroutine finish even if nobody is reading yet.
func newJobWaiter(log *logrus.Entry, panicHandler async.PanicHandler) *jobWaiter {
	w := new(jobWaiter)
	w.ch = make(chan jobWaiterMessagePair)
	w.doneCh = make(chan error, 2)
	w.log = log
	w.panicHandler = panicHandler

	return w
}
// close closes the message channel. A still-running waiter goroutine then exits and publishes its result. Any
// onTaskCreated/onTaskFinished call made afterwards panics (send on a closed channel).
func (j *jobWaiter) close() {
	msgs := j.ch
	close(msgs)
}
// sendMessage delivers the event m, together with err, to the waiter goroutine. The send blocks until the goroutine
// receives it, so it must not be used after the waiter has finished.
func (j *jobWaiter) sendMessage(m JobWaiterMessage, err error) {
	msg := jobWaiterMessagePair{m: m, err: err}
	j.ch <- msg
}
// onTaskFinished reports that one task has completed. A non-nil err is kept if it is the first error seen.
func (j *jobWaiter) onTaskFinished(err error) {
	kind := JobWaiterMessageFinished
	j.sendMessage(kind, err)
}
// onTaskCreated reports that one more task must finish before the waiter completes.
func (j *jobWaiter) onTaskCreated() {
	var noErr error
	j.sendMessage(JobWaiterMessageCreated, noErr)
}
// begin starts the goroutine that tracks outstanding tasks.
//
// The count starts at one. Each JobWaiterMessageCreated increments it and each JobWaiterMessageFinished decrements
// it; the first non-nil error is remembered. When the count reaches zero or below, or ch is closed, the goroutine
// writes that first error (nil if none) to doneCh, closes doneCh and exits. Panics are passed to async.HandlePanic
// with panicHandler.
func (j *jobWaiter) begin() {
	go func() {
		defer async.HandlePanic(j.panicHandler)

		pending := 1

		var firstErr error

		// Deferred after HandlePanic so it runs before it. doneCh is written and closed even when the loop
		// panics, so readers of doneCh are never left waiting.
		defer func() {
			j.doneCh <- firstErr
			close(j.doneCh)
		}()

		for msg := range j.ch {
			switch msg.m {
			case JobWaiterMessageCreated:
				pending++
			case JobWaiterMessageFinished:
				pending--

				if firstErr == nil && msg.err != nil {
					firstErr = msg.err
				}
			default:
				j.log.Errorf("Unknown message type: %v", msg.m)
				continue
			}

			if pending > 0 {
				continue
			}

			if pending < 0 {
				// Log through the waiter's entry (not the global logger) so the message keeps the job's fields.
				j.log.Error("Child count less than 0, shouldn't happen...")
			}

			j.log.Info("All child jobs completed")

			return
		}
	}()
}