fix(GODT-2829): Sync Service fixes

Fix tracking of child jobs. The build stage splits the incoming work
even further, but this was not reflected in the wait group counter. This
also fixes an issue where the cache was cleared to late.

Add more debug info for analysis.

Refactor sync state interface in order to have persistent sync rate.
This commit is contained in:
Leander Beernaert
2023-08-25 15:01:03 +02:00
parent 78f7cbdc79
commit aa77a67a1c
16 changed files with 221 additions and 189 deletions

View File

@ -299,11 +299,11 @@ EventSubscriber,MessageEventHandler,LabelEventHandler,AddressEventHandler,Refres
> internal/events/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider,Telemetry \
> internal/services/useridentity/mocks/mocks.go
mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/sync" -package sync github.com/ProtonMail/proton-bridge/v3/internal/services/sync \
mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice" -package syncservice github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice \
ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,\
StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier \
> tmp
mv tmp internal/services/sync/mocks_test.go
mv tmp internal/services/syncservice/mocks_test.go
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog lint-bug-report

View File

@ -29,7 +29,6 @@ type APIClient interface {
GetLabels(ctx context.Context, labelTypes ...proton.LabelType) ([]proton.Label, error)
GetMessage(ctx context.Context, messageID string) (proton.Message, error)
GetMessageMetadataPage(ctx context.Context, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error)
GetMessageIDs(ctx context.Context, afterID string) ([]string, error)
GetFullMessage(ctx context.Context, messageID string, scheduler proton.Scheduler, storageProvider proton.AttachmentAllocator) (proton.FullMessage, error)
GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error
GetAttachment(ctx context.Context, attachmentID string) ([]byte, error)

View File

@ -117,7 +117,11 @@ func (t *Handler) Execute(
}
t.log.WithField("duration", time.Since(start)).Info("Finished user sync")
t.syncFinishedCh <- err
select {
case <-ctx.Done():
return
case t.syncFinishedCh <- err:
}
})
}
@ -179,8 +183,12 @@ func (t *Handler) run(ctx context.Context,
if err := t.syncState.SetMessageCount(ctx, totalMessageCount); err != nil {
return fmt.Errorf("failed to store message count: %w", err)
}
syncStatus.TotalMessageCount = totalMessageCount
}
syncReporter.InitializeProgressCounter(ctx, syncStatus.NumSyncedMessages, syncStatus.TotalMessageCount)
if !syncStatus.HasMessages {
t.log.Info("Syncing messages")
@ -198,6 +206,11 @@ func (t *Handler) run(ctx context.Context,
t.log,
)
stageContext.metadataFetched = syncStatus.NumSyncedMessages
stageContext.totalMessageCount = syncStatus.TotalMessageCount
defer stageContext.Close()
t.regulator.Sync(ctx, stageContext)
// Wait on reply

View File

@ -57,6 +57,7 @@ func TestTask_NoStateAndSucceeds(t *testing.T) {
})
call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil)
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
call5 := tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).After(call5).Times(1).DoAndReturn(func(_ context.Context) (Status, error) {
@ -125,6 +126,7 @@ func TestTask_StateHasLabels(t *testing.T) {
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
}
{
@ -170,6 +172,7 @@ func TestTask_StateHasLabelsAndMessageCount(t *testing.T) {
})
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
}
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
@ -219,6 +222,8 @@ func TestTask_RepeatsOnSyncFailure(t *testing.T) {
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
{
call0 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
return Status{

View File

@ -28,8 +28,8 @@ import (
)
type StateProvider interface {
AddFailedMessageID(context.Context, string) error
RemFailedMessageID(context.Context, string) error
AddFailedMessageID(context.Context, ...string) error
RemFailedMessageID(context.Context, ...string) error
GetSyncStatus(context.Context) (Status, error)
ClearSyncStatus(context.Context) error
SetHasLabels(context.Context, bool) error
@ -85,4 +85,5 @@ type Reporter interface {
OnFinished(ctx context.Context)
OnError(ctx context.Context, err error)
OnProgress(ctx context.Context, delta int64)
InitializeProgressCounter(ctx context.Context, current int64, total int64)
}

View File

@ -24,6 +24,7 @@ import (
"sync"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/sirupsen/logrus"
)
@ -54,6 +55,9 @@ type Job struct {
panicHandler async.PanicHandler
downloadCache *DownloadCache
metadataFetched int64
totalMessageCount int64
}
func NewJob(ctx context.Context,
@ -178,6 +182,36 @@ func (s *childJob) userID() string {
return s.job.userID
}
func (s *childJob) chunkDivide(chunks [][]proton.FullMessage) []childJob {
numChunks := len(chunks)
if numChunks == 1 {
return []childJob{*s}
}
result := make([]childJob, numChunks)
for i := 0; i < numChunks-1; i++ {
result[i] = s.job.newChildJob(chunks[i][len(chunks[i])-1].ID, int64(len(chunks[i])))
collectIDs(&result[i], chunks[i])
}
result[numChunks-1] = *s
collectIDs(&result[numChunks-1], chunks[numChunks-1])
return result
}
func collectIDs(j *childJob, msgs []proton.FullMessage) {
j.cachedAttachmentIDs = make([]string, 0, len(msgs))
j.cachedMessageIDs = make([]string, 0, len(msgs))
for _, msg := range msgs {
j.cachedMessageIDs = append(j.cachedMessageIDs, msg.ID)
for _, attach := range msg.Attachments {
j.cachedAttachmentIDs = append(j.cachedAttachmentIDs, attach.ID)
}
}
}
func (s *childJob) onFinished(ctx context.Context) {
s.job.log.Infof("Child job finished")
s.job.onJobFinished(ctx, s.lastMessageID, s.messageCount)

View File

@ -1,7 +1,7 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/sync (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier)
// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier)
// Package sync is a generated GoMock package.
// Package syncservice is a generated GoMock package.
package syncservice
import (
@ -53,22 +53,6 @@ func (mr *MockApplyStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockApplyStageInput)(nil).Consume), arg0)
}
// TryConsume mocks base method.
func (m *MockApplyStageInput) TryConsume(arg0 context.Context) (ApplyRequest, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(ApplyRequest)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockApplyStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockApplyStageInput)(nil).TryConsume), arg0)
}
// MockBuildStageInput is a mock of BuildStageInput interface.
type MockBuildStageInput struct {
ctrl *gomock.Controller
@ -107,22 +91,6 @@ func (mr *MockBuildStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockBuildStageInput)(nil).Consume), arg0)
}
// TryConsume mocks base method.
func (m *MockBuildStageInput) TryConsume(arg0 context.Context) (BuildRequest, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(BuildRequest)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockBuildStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockBuildStageInput)(nil).TryConsume), arg0)
}
// MockBuildStageOutput is a mock of BuildStageOutput interface.
type MockBuildStageOutput struct {
ctrl *gomock.Controller
@ -208,22 +176,6 @@ func (mr *MockDownloadStageInputMockRecorder) Consume(arg0 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockDownloadStageInput)(nil).Consume), arg0)
}
// TryConsume mocks base method.
func (m *MockDownloadStageInput) TryConsume(arg0 context.Context) (DownloadRequest, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(DownloadRequest)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockDownloadStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockDownloadStageInput)(nil).TryConsume), arg0)
}
// MockDownloadStageOutput is a mock of DownloadStageOutput interface.
type MockDownloadStageOutput struct {
ctrl *gomock.Controller
@ -309,22 +261,6 @@ func (mr *MockMetadataStageInputMockRecorder) Consume(arg0 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockMetadataStageInput)(nil).Consume), arg0)
}
// TryConsume mocks base method.
func (m *MockMetadataStageInput) TryConsume(arg0 context.Context) (*Job, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(*Job)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockMetadataStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockMetadataStageInput)(nil).TryConsume), arg0)
}
// MockMetadataStageOutput is a mock of MetadataStageOutput interface.
type MockMetadataStageOutput struct {
ctrl *gomock.Controller
@ -396,17 +332,22 @@ func (m *MockStateProvider) EXPECT() *MockStateProviderMockRecorder {
}
// AddFailedMessageID mocks base method.
func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 string) error {
func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 ...string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddFailedMessageID", arg0, arg1)
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "AddFailedMessageID", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// AddFailedMessageID indicates an expected call of AddFailedMessageID.
func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), arg0, arg1)
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), varargs...)
}
// ClearSyncStatus mocks base method.
@ -439,17 +380,22 @@ func (mr *MockStateProviderMockRecorder) GetSyncStatus(arg0 interface{}) *gomock
}
// RemFailedMessageID mocks base method.
func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 string) error {
func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 ...string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemFailedMessageID", arg0, arg1)
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "RemFailedMessageID", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// RemFailedMessageID indicates an expected call of RemFailedMessageID.
func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0, arg1 interface{}) *gomock.Call {
func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), arg0, arg1)
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), varargs...)
}
// SetHasLabels mocks base method.
@ -777,21 +723,6 @@ func (mr *MockAPIClientMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockAPIClient)(nil).GetMessage), arg0, arg1)
}
// GetMessageIDs mocks base method.
func (m *MockAPIClient) GetMessageIDs(arg0 context.Context, arg1 string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMessageIDs", arg0, arg1)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMessageIDs indicates an expected call of GetMessageIDs.
func (mr *MockAPIClientMockRecorder) GetMessageIDs(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageIDs", reflect.TypeOf((*MockAPIClient)(nil).GetMessageIDs), arg0, arg1)
}
// GetMessageMetadataPage mocks base method.
func (m *MockAPIClient) GetMessageMetadataPage(arg0 context.Context, arg1, arg2 int, arg3 proton.MessageFilter) ([]proton.MessageMetadata, error) {
m.ctrl.T.Helper()
@ -830,6 +761,18 @@ func (m *MockReporter) EXPECT() *MockReporterMockRecorder {
return m.recorder
}
// InitializeProgressCounter mocks base method.
func (m *MockReporter) InitializeProgressCounter(arg0 context.Context, arg1, arg2 int64) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "InitializeProgressCounter", arg0, arg1, arg2)
}
// InitializeProgressCounter indicates an expected call of InitializeProgressCounter.
func (mr *MockReporterMockRecorder) InitializeProgressCounter(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitializeProgressCounter", reflect.TypeOf((*MockReporter)(nil).InitializeProgressCounter), arg0, arg1, arg2)
}
// OnError mocks base method.
func (m *MockReporter) OnError(arg0 context.Context, arg1 error) {
m.ctrl.T.Helper()

View File

@ -48,7 +48,7 @@ func NewService(reporter reporter.Reporter,
return &Service{
limits: limits,
metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem),
metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem, panicHandler),
downloadStage: NewDownloadStage(downloadCh, buildCh, 20, panicHandler),
buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter),
applyStage: NewApplyStage(applyCh),

View File

@ -22,6 +22,7 @@ import (
"errors"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/logging"
"github.com/sirupsen/logrus"
)
@ -44,7 +45,15 @@ func NewApplyStage(input ApplyStageInput) *ApplyStage {
}
func (a *ApplyStage) Run(group *async.Group) {
group.Once(a.run)
group.Once(func(ctx context.Context) {
logging.DoAnnotated(
ctx,
func(ctx context.Context) {
a.run(ctx)
},
logging.Labels{"sync-stage": "apply"},
)
})
}
func (a *ApplyStage) run(ctx context.Context) {

View File

@ -24,6 +24,7 @@ import (
"runtime"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/logging"
"github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto"
@ -70,7 +71,15 @@ func NewBuildStage(
}
func (b *BuildStage) Run(group *async.Group) {
group.Once(b.run)
group.Once(func(ctx context.Context) {
logging.DoAnnotated(
ctx,
func(ctx context.Context) {
b.run(ctx)
},
logging.Labels{"sync-stage": "build"},
)
})
}
func (b *BuildStage) run(ctx context.Context) {
@ -94,7 +103,21 @@ func (b *BuildStage) run(ctx context.Context) {
err = req.job.messageBuilder.WithKeys(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
chunks := chunkSyncBuilderBatch(req.batch, b.maxBuildMem)
for _, chunk := range chunks {
// This stage will split our existing job into many smaller bits. We need to update the Parent Job so
// that it correctly tracks the lifetime of extra jobs. Additionally, we also need to make sure
// that only the last chunk contains the metadata to clear the cache.
chunkedJobs := req.chunkDivide(chunks)
for idx, chunk := range chunks {
if chunkedJobs[idx].checkCancelled() {
// Cancel all other chunks.
for i := idx + 1; i < len(chunkedJobs); i++ {
chunkedJobs[i].checkCancelled()
}
return nil
}
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (BuildResult, error) {
defer async.HandlePanic(b.panicHandler)
@ -135,21 +158,29 @@ func (b *BuildStage) run(ctx context.Context) {
return BuildResult{}, nil
}
if err := req.job.state.RemFailedMessageID(req.getContext(), res.MessageID); err != nil {
req.job.log.WithError(err).Error("Failed to remove failed message ID")
}
return res, nil
})
if err != nil {
return err
}
success := xslices.Filter(result, func(t BuildResult) bool {
return t.Update != nil
})
if len(success) > 0 {
successIDs := xslices.Map(success, func(t BuildResult) string {
return t.MessageID
})
if err := req.job.state.RemFailedMessageID(req.getContext(), successIDs...); err != nil {
req.job.log.WithError(err).Error("Failed to remove failed message ID")
}
}
b.output.Produce(ctx, ApplyRequest{
childJob: req.childJob,
messages: xslices.Filter(result, func(t BuildResult) bool {
return t.Update != nil
}),
childJob: chunkedJobs[idx],
messages: success,
})
}

View File

@ -153,7 +153,7 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) {
buildError := errors.New("it failed")
tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(BuildResult{}, buildError)
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"}))
mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
"userID": "u",
"messageID": "MSG",
@ -204,7 +204,7 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin
childJob := tj.job.newChildJob("f", 10)
tj.job.end()
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"}))
mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
"userID": "u",
"messageID": "MSG",

View File

@ -112,8 +112,11 @@ func (d *DownloadStage) run(ctx context.Context) {
}
var attData [][]byte
if msg.NumAttachments > 0 {
attData = make([][]byte, msg.NumAttachments)
numAttachments := len(msg.Attachments)
if numAttachments > 0 {
attData = make([][]byte, numAttachments)
}
return proton.FullMessage{Message: msg, AttData: attData}, nil
@ -139,7 +142,8 @@ func (d *DownloadStage) run(ctx context.Context) {
attachmentIDs := make([]string, 0, len(result))
for msgIdx, v := range result {
for attIdx := 0; attIdx < v.NumAttachments; attIdx++ {
numAttachments := len(v.Attachments)
for attIdx := 0; attIdx < numAttachments; attIdx++ {
attachmentIndices = append(attachmentIndices, attachmentMeta{
msgIdx: msgIdx,
attIdx: attIdx,

View File

@ -333,8 +333,7 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) {
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{
MessageMetadata: proton.MessageMetadata{
ID: "msg",
NumAttachments: 1,
ID: "msg",
},
Header: "",
ParsedHeaders: nil,
@ -436,7 +435,7 @@ func buildDownloadStageAttachments(msg *proton.FullMessage, index int) {
func genDownloadStageAttachmentInfo(msg *proton.FullMessage, msgIdx int, count int) {
msg.Attachments = make([]proton.Attachment, count)
msg.AttData = make([][]byte, count)
msg.NumAttachments = count
for i := 0; i < count; i++ {
data := fmt.Sprintf("msg-%v-att-%v", msgIdx, i)
msg.Attachments[i] = proton.Attachment{

View File

@ -19,11 +19,12 @@ package syncservice
import (
"context"
"errors"
"github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/logging"
"github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/network"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus"
)
@ -37,80 +38,91 @@ type MetadataStage struct {
output MetadataStageOutput
input MetadataStageInput
maxDownloadMem uint64
jobs []*metadataIterator
log *logrus.Entry
panicHandler async.PanicHandler
}
func NewMetadataStage(input MetadataStageInput, output MetadataStageOutput, maxDownloadMem uint64) *MetadataStage {
return &MetadataStage{input: input, output: output, maxDownloadMem: maxDownloadMem, log: logrus.WithField("sync-stage", "metadata")}
func NewMetadataStage(
input MetadataStageInput,
output MetadataStageOutput,
maxDownloadMem uint64,
panicHandler async.PanicHandler,
) *MetadataStage {
return &MetadataStage{
input: input,
output: output,
maxDownloadMem: maxDownloadMem,
log: logrus.WithField("sync-stage", "metadata"),
panicHandler: panicHandler,
}
}
const MetadataPageSize = 150
const MetadataMaxMessages = 250
func (m MetadataStage) Run(group *async.Group) {
func (m *MetadataStage) Run(group *async.Group) {
group.Once(func(ctx context.Context) {
m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{})
logging.DoAnnotated(
ctx,
func(ctx context.Context) {
m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{})
},
logging.Labels{"sync-stage": "metadata"},
)
})
}
func (m MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) {
func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) {
defer m.output.Close()
group := async.NewGroup(ctx, m.panicHandler)
defer group.CancelAndWait()
for {
if ctx.Err() != nil {
return
}
// Check if new job has been submitted
job, ok, err := m.input.TryConsume(ctx)
job, err := m.input.Consume(ctx)
if err != nil {
m.log.WithError(err).Error("Error trying to retrieve more work")
if !(errors.Is(err, context.Canceled) || errors.Is(err, ErrNoMoreInput)) {
m.log.WithError(err).Error("Error trying to retrieve more work")
}
return
}
if ok {
job.begin()
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
if err != nil {
job.onError(err)
continue
}
m.jobs = append(m.jobs, state)
job.begin()
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
if err != nil {
job.onError(err)
continue
}
// Iterate over all jobs and produce work.
for i := 0; i < len(m.jobs); {
job := m.jobs[i]
group.Once(func(ctx context.Context) {
for {
if state.stage.ctx.Err() != nil {
state.stage.end()
return
}
// If the job's context has been cancelled, remove from the list.
if job.stage.ctx.Err() != nil {
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
job.stage.end()
continue
// Check for more work.
output, hasMore, err := state.Next(m.maxDownloadMem, metadataPageSize, maxMessages)
if err != nil {
state.stage.onError(err)
return
}
// If there is actually more work, push it down the pipeline.
if len(output.ids) != 0 {
state.stage.metadataFetched += int64(len(output.ids))
job.log.Debugf("Metada collected: %v/%v", state.stage.metadataFetched, state.stage.totalMessageCount)
m.output.Produce(ctx, output)
}
// If this job has no more work left, signal completion.
if !hasMore {
state.stage.end()
return
}
}
// Check for more work.
output, hasMore, err := job.Next(m.maxDownloadMem, metadataPageSize, maxMessages)
if err != nil {
job.stage.onError(err)
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
continue
}
// If there is actually more work, push it down the pipeline.
if len(output.ids) != 0 {
m.output.Produce(ctx, output)
}
// If this job has no more work left, signal completion.
if !hasMore {
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
job.stage.end()
continue
}
i++
}
})
}
}

View File

@ -48,7 +48,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) {
output := NewChannelConsumerProducer[DownloadRequest]()
ctx, cancel := context.WithCancel(context.Background())
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})
numMessages := 50
messageSize := 100
@ -86,7 +86,7 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
output := NewChannelConsumerProducer[DownloadRequest]()
ctx, cancel := context.WithCancel(context.Background())
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})
go func() {
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
@ -135,7 +135,7 @@ func TestMetadataStage_RunInterleaved(t *testing.T) {
output := NewChannelConsumerProducer[DownloadRequest]()
ctx, cancel := context.WithCancel(context.Background())
metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})
numMessages := 50
messageSize := 100

View File

@ -31,7 +31,6 @@ var ErrNoMoreInput = errors.New("no more input")
type StageInputConsumer[T any] interface {
Consume(ctx context.Context) (T, error)
TryConsume(ctx context.Context) (T, bool, error)
}
type ChannelConsumerProducer[T any] struct {
@ -66,20 +65,3 @@ func (c ChannelConsumerProducer[T]) Consume(ctx context.Context) (T, error) {
return t, nil
}
}
func (c ChannelConsumerProducer[T]) TryConsume(ctx context.Context) (T, bool, error) {
select {
case <-ctx.Done():
var t T
return t, false, ctx.Err()
case t, ok := <-c.ch:
if !ok {
return t, false, ErrNoMoreInput
}
return t, true, nil
default:
var t T
return t, false, nil
}
}