Mirror of https://github.com/ProtonMail/proton-bridge.git (synced 2026-02-04 00:08:33 +00:00)
fix(GODT-2829): Sync Service fixes
Fix tracking of child jobs: the build stage splits the incoming work even further, but this was not reflected in the wait group counter. This also fixes an issue where the cache was cleared too late. Add more debug info for analysis. Refactor the sync state interface in order to have a persistent sync rate.
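The counting problem behind the first fix can be sketched in a few lines of self-contained Go. The types and names below (parentJob, newChild, childDone) are illustrative stand-ins, not the actual syncservice API: the point is only that every extra chunk created by a later stage must be registered with the parent before it is waited on, otherwise the parent's counter reaches zero early and cleanup such as clearing the download cache runs while chunks are still in flight.

```go
package main

import (
	"fmt"
	"sync"
)

// parentJob is an illustrative stand-in for the sync job bookkeeping: it counts
// outstanding child jobs and runs a cleanup callback (e.g. clearing a download
// cache) only once, after the last child has finished.
type parentJob struct {
	wg      sync.WaitGroup
	once    sync.Once
	cleanup func()
}

func (p *parentJob) newChild()  { p.wg.Add(1) }
func (p *parentJob) childDone() { p.wg.Done() }

func (p *parentJob) wait() {
	p.wg.Wait()
	p.once.Do(p.cleanup)
}

// splitIntoChunks models the build stage splitting one child job into several
// smaller ones: every chunk beyond the first must be added to the parent's
// counter, which is exactly the accounting that used to be skipped.
func splitIntoChunks(p *parentJob, chunks int) {
	for i := 1; i < chunks; i++ { // the original child already holds one count
		p.newChild()
	}
}

func main() {
	p := &parentJob{cleanup: func() { fmt.Println("cache cleared") }}
	p.newChild() // one child produced by the metadata stage

	const chunks = 3
	splitIntoChunks(p, chunks) // build stage splits the work further

	var applied sync.WaitGroup
	for i := 0; i < chunks; i++ {
		applied.Add(1)
		go func(i int) {
			defer applied.Done()
			fmt.Println("chunk", i, "applied")
			p.childDone()
		}(i)
	}

	applied.Wait()
	p.wait() // cleanup runs only after every chunk reported completion
}
```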
Makefile (4 changed lines)

@@ -299,11 +299,11 @@ EventSubscriber,MessageEventHandler,LabelEventHandler,AddressEventHandler,Refres
         > internal/events/mocks/mocks.go
     mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider,Telemetry \
         > internal/services/useridentity/mocks/mocks.go
-    mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/sync" -package sync github.com/ProtonMail/proton-bridge/v3/internal/services/sync \
+    mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice" -package syncservice github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice \
         ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,\
         StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier \
         > tmp
-    mv tmp internal/services/sync/mocks_test.go
+    mv tmp internal/services/syncservice/mocks_test.go

 lint: gofiles lint-golang lint-license lint-dependencies lint-changelog lint-bug-report

@@ -29,7 +29,6 @@ type APIClient interface {
     GetLabels(ctx context.Context, labelTypes ...proton.LabelType) ([]proton.Label, error)
     GetMessage(ctx context.Context, messageID string) (proton.Message, error)
     GetMessageMetadataPage(ctx context.Context, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error)
-    GetMessageIDs(ctx context.Context, afterID string) ([]string, error)
     GetFullMessage(ctx context.Context, messageID string, scheduler proton.Scheduler, storageProvider proton.AttachmentAllocator) (proton.FullMessage, error)
     GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error
     GetAttachment(ctx context.Context, attachmentID string) ([]byte, error)
@@ -117,7 +117,11 @@ func (t *Handler) Execute(
         }

         t.log.WithField("duration", time.Since(start)).Info("Finished user sync")
-        t.syncFinishedCh <- err
+        select {
+        case <-ctx.Done():
+            return
+        case t.syncFinishedCh <- err:
+        }
     })
 }

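The hunk above replaces a bare send on syncFinishedCh with a select against ctx.Done(), so the goroutine can give up on delivering the result when the service is shutting down instead of blocking forever. A minimal, stand-alone sketch of the same pattern (illustrative names, not the Bridge code):

```go
package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

// sendResult races a channel send against context cancellation, the same shape
// as the syncFinishedCh change above.
func sendResult(ctx context.Context, out chan<- error, err error) bool {
	select {
	case <-ctx.Done():
		return false // nobody is listening any more; drop the result rather than leak the goroutine
	case out <- err:
		return true
	}
}

func main() {
	results := make(chan error) // unbuffered reply channel

	// With a receiver present the send goes through.
	go func() { fmt.Println("received:", <-results) }()
	fmt.Println("delivered:", sendResult(context.Background(), results, errors.New("sync failed")))

	// With the context already cancelled the send is abandoned.
	ctx, cancel := context.WithCancel(context.Background())
	cancel()
	fmt.Println("delivered:", sendResult(ctx, results, nil))

	time.Sleep(50 * time.Millisecond) // let the receiver goroutine print
}
```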
@@ -179,8 +183,12 @@ func (t *Handler) run(ctx context.Context,
         if err := t.syncState.SetMessageCount(ctx, totalMessageCount); err != nil {
             return fmt.Errorf("failed to store message count: %w", err)
         }
+
+        syncStatus.TotalMessageCount = totalMessageCount
     }

+    syncReporter.InitializeProgressCounter(ctx, syncStatus.NumSyncedMessages, syncStatus.TotalMessageCount)
+
     if !syncStatus.HasMessages {
         t.log.Info("Syncing messages")

@@ -198,6 +206,11 @@ func (t *Handler) run(ctx context.Context,
             t.log,
         )

+        stageContext.metadataFetched = syncStatus.NumSyncedMessages
+        stageContext.totalMessageCount = syncStatus.TotalMessageCount
+
+        defer stageContext.Close()
+
         t.regulator.Sync(ctx, stageContext)

         // Wait on reply
@@ -57,6 +57,7 @@ func TestTask_NoStateAndSucceeds(t *testing.T) {
     })
     call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil)
     call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
+    tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
     call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
     call5 := tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
     tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).After(call5).Times(1).DoAndReturn(func(_ context.Context) (Status, error) {
@@ -125,6 +126,7 @@ func TestTask_StateHasLabels(t *testing.T) {
         call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
         call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
         tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
+        tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
     }

     {
@@ -170,6 +172,7 @@ func TestTask_StateHasLabelsAndMessageCount(t *testing.T) {
         })
         call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
         tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
+        tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
     }

     tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
@@ -219,6 +222,8 @@ func TestTask_RepeatsOnSyncFailure(t *testing.T) {

     tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)

+    tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
+
     {
         call0 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
             return Status{
@@ -28,8 +28,8 @@ import (
 )

 type StateProvider interface {
-    AddFailedMessageID(context.Context, string) error
-    RemFailedMessageID(context.Context, string) error
+    AddFailedMessageID(context.Context, ...string) error
+    RemFailedMessageID(context.Context, ...string) error
     GetSyncStatus(context.Context) (Status, error)
     ClearSyncStatus(context.Context) error
     SetHasLabels(context.Context, bool) error
@@ -85,4 +85,5 @@ type Reporter interface {
     OnFinished(ctx context.Context)
     OnError(ctx context.Context, err error)
     OnProgress(ctx context.Context, delta int64)
+    InitializeProgressCounter(ctx context.Context, current int64, total int64)
 }
@@ -24,6 +24,7 @@ import (
     "sync"

     "github.com/ProtonMail/gluon/async"
+    "github.com/ProtonMail/go-proton-api"
     "github.com/sirupsen/logrus"
 )

@@ -54,6 +55,9 @@ type Job struct {

     panicHandler async.PanicHandler
     downloadCache *DownloadCache
+
+    metadataFetched int64
+    totalMessageCount int64
 }

 func NewJob(ctx context.Context,
@@ -178,6 +182,36 @@ func (s *childJob) userID() string {
     return s.job.userID
 }

+func (s *childJob) chunkDivide(chunks [][]proton.FullMessage) []childJob {
+    numChunks := len(chunks)
+
+    if numChunks == 1 {
+        return []childJob{*s}
+    }
+
+    result := make([]childJob, numChunks)
+    for i := 0; i < numChunks-1; i++ {
+        result[i] = s.job.newChildJob(chunks[i][len(chunks[i])-1].ID, int64(len(chunks[i])))
+        collectIDs(&result[i], chunks[i])
+    }
+
+    result[numChunks-1] = *s
+    collectIDs(&result[numChunks-1], chunks[numChunks-1])
+
+    return result
+}
+
+func collectIDs(j *childJob, msgs []proton.FullMessage) {
+    j.cachedAttachmentIDs = make([]string, 0, len(msgs))
+    j.cachedMessageIDs = make([]string, 0, len(msgs))
+    for _, msg := range msgs {
+        j.cachedMessageIDs = append(j.cachedMessageIDs, msg.ID)
+        for _, attach := range msg.Attachments {
+            j.cachedAttachmentIDs = append(j.cachedAttachmentIDs, attach.ID)
+        }
+    }
+}
+
 func (s *childJob) onFinished(ctx context.Context) {
     s.job.log.Infof("Child job finished")
     s.job.onJobFinished(ctx, s.lastMessageID, s.messageCount)
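The chunkDivide and collectIDs additions above give each chunk its own child job plus the exact message and attachment IDs it is responsible for, and deliberately reuse the original child job for the last chunk so its final cleanup happens only once every chunk exists. A rough sketch of that division with made-up types (not the proton/syncservice types):

```go
package main

import "fmt"

// message stands in for proton.FullMessage (illustrative only).
type message struct {
	ID          string
	Attachments []string
}

// chunkTask mirrors the idea behind childJob.chunkDivide: each chunk remembers
// the message and attachment IDs it owns so their cache entries can be released
// as soon as that chunk is applied, and only the final chunk keeps the identity
// of the original job.
type chunkTask struct {
	isOriginal    bool
	messageIDs    []string
	attachmentIDs []string
}

func divide(chunks [][]message) []chunkTask {
	out := make([]chunkTask, len(chunks))
	for i, chunk := range chunks {
		t := chunkTask{isOriginal: i == len(chunks)-1} // original job goes last
		for _, m := range chunk {
			t.messageIDs = append(t.messageIDs, m.ID)
			t.attachmentIDs = append(t.attachmentIDs, m.Attachments...)
		}
		out[i] = t
	}
	return out
}

func main() {
	chunks := [][]message{
		{{ID: "m1", Attachments: []string{"a1"}}, {ID: "m2"}},
		{{ID: "m3", Attachments: []string{"a2", "a3"}}},
	}
	for i, task := range divide(chunks) {
		fmt.Printf("chunk %d: original=%v messages=%v attachments=%v\n",
			i, task.isOriginal, task.messageIDs, task.attachmentIDs)
	}
}
```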
@@ -1,7 +1,7 @@
 // Code generated by MockGen. DO NOT EDIT.
-// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/sync (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier)
+// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier)

-// Package sync is a generated GoMock package.
+// Package syncservice is a generated GoMock package.
 package syncservice

 import (
@@ -53,22 +53,6 @@ func (mr *MockApplyStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal
     return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockApplyStageInput)(nil).Consume), arg0)
 }

-// TryConsume mocks base method.
-func (m *MockApplyStageInput) TryConsume(arg0 context.Context) (ApplyRequest, bool, error) {
-    m.ctrl.T.Helper()
-    ret := m.ctrl.Call(m, "TryConsume", arg0)
-    ret0, _ := ret[0].(ApplyRequest)
-    ret1, _ := ret[1].(bool)
-    ret2, _ := ret[2].(error)
-    return ret0, ret1, ret2
-}
-
-// TryConsume indicates an expected call of TryConsume.
-func (mr *MockApplyStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
-    mr.mock.ctrl.T.Helper()
-    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockApplyStageInput)(nil).TryConsume), arg0)
-}
-
 // MockBuildStageInput is a mock of BuildStageInput interface.
 type MockBuildStageInput struct {
     ctrl *gomock.Controller
@@ -107,22 +91,6 @@ func (mr *MockBuildStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal
     return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockBuildStageInput)(nil).Consume), arg0)
 }

-// TryConsume mocks base method.
-func (m *MockBuildStageInput) TryConsume(arg0 context.Context) (BuildRequest, bool, error) {
-    m.ctrl.T.Helper()
-    ret := m.ctrl.Call(m, "TryConsume", arg0)
-    ret0, _ := ret[0].(BuildRequest)
-    ret1, _ := ret[1].(bool)
-    ret2, _ := ret[2].(error)
-    return ret0, ret1, ret2
-}
-
-// TryConsume indicates an expected call of TryConsume.
-func (mr *MockBuildStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
-    mr.mock.ctrl.T.Helper()
-    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockBuildStageInput)(nil).TryConsume), arg0)
-}
-
 // MockBuildStageOutput is a mock of BuildStageOutput interface.
 type MockBuildStageOutput struct {
     ctrl *gomock.Controller
@@ -208,22 +176,6 @@ func (mr *MockDownloadStageInputMockRecorder) Consume(arg0 interface{}) *gomock.
     return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockDownloadStageInput)(nil).Consume), arg0)
 }

-// TryConsume mocks base method.
-func (m *MockDownloadStageInput) TryConsume(arg0 context.Context) (DownloadRequest, bool, error) {
-    m.ctrl.T.Helper()
-    ret := m.ctrl.Call(m, "TryConsume", arg0)
-    ret0, _ := ret[0].(DownloadRequest)
-    ret1, _ := ret[1].(bool)
-    ret2, _ := ret[2].(error)
-    return ret0, ret1, ret2
-}
-
-// TryConsume indicates an expected call of TryConsume.
-func (mr *MockDownloadStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
-    mr.mock.ctrl.T.Helper()
-    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockDownloadStageInput)(nil).TryConsume), arg0)
-}
-
 // MockDownloadStageOutput is a mock of DownloadStageOutput interface.
 type MockDownloadStageOutput struct {
     ctrl *gomock.Controller
@@ -309,22 +261,6 @@ func (mr *MockMetadataStageInputMockRecorder) Consume(arg0 interface{}) *gomock.
     return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockMetadataStageInput)(nil).Consume), arg0)
 }

-// TryConsume mocks base method.
-func (m *MockMetadataStageInput) TryConsume(arg0 context.Context) (*Job, bool, error) {
-    m.ctrl.T.Helper()
-    ret := m.ctrl.Call(m, "TryConsume", arg0)
-    ret0, _ := ret[0].(*Job)
-    ret1, _ := ret[1].(bool)
-    ret2, _ := ret[2].(error)
-    return ret0, ret1, ret2
-}
-
-// TryConsume indicates an expected call of TryConsume.
-func (mr *MockMetadataStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
-    mr.mock.ctrl.T.Helper()
-    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockMetadataStageInput)(nil).TryConsume), arg0)
-}
-
 // MockMetadataStageOutput is a mock of MetadataStageOutput interface.
 type MockMetadataStageOutput struct {
     ctrl *gomock.Controller
@@ -396,17 +332,22 @@ func (m *MockStateProvider) EXPECT() *MockStateProviderMockRecorder {
 }

 // AddFailedMessageID mocks base method.
-func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 string) error {
+func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 ...string) error {
     m.ctrl.T.Helper()
-    ret := m.ctrl.Call(m, "AddFailedMessageID", arg0, arg1)
+    varargs := []interface{}{arg0}
+    for _, a := range arg1 {
+        varargs = append(varargs, a)
+    }
+    ret := m.ctrl.Call(m, "AddFailedMessageID", varargs...)
     ret0, _ := ret[0].(error)
     return ret0
 }

 // AddFailedMessageID indicates an expected call of AddFailedMessageID.
-func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0, arg1 interface{}) *gomock.Call {
+func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
     mr.mock.ctrl.T.Helper()
-    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), arg0, arg1)
+    varargs := append([]interface{}{arg0}, arg1...)
+    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), varargs...)
 }

 // ClearSyncStatus mocks base method.
@@ -439,17 +380,22 @@ func (mr *MockStateProviderMockRecorder) GetSyncStatus(arg0 interface{}) *gomock
 }

 // RemFailedMessageID mocks base method.
-func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 string) error {
+func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 ...string) error {
     m.ctrl.T.Helper()
-    ret := m.ctrl.Call(m, "RemFailedMessageID", arg0, arg1)
+    varargs := []interface{}{arg0}
+    for _, a := range arg1 {
+        varargs = append(varargs, a)
+    }
+    ret := m.ctrl.Call(m, "RemFailedMessageID", varargs...)
     ret0, _ := ret[0].(error)
     return ret0
 }

 // RemFailedMessageID indicates an expected call of RemFailedMessageID.
-func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0, arg1 interface{}) *gomock.Call {
+func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
     mr.mock.ctrl.T.Helper()
-    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), arg0, arg1)
+    varargs := append([]interface{}{arg0}, arg1...)
+    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), varargs...)
 }

 // SetHasLabels mocks base method.
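The regenerated mock methods above follow from the StateProvider interface change: AddFailedMessageID and RemFailedMessageID now take a variadic list of IDs, so a whole batch can be recorded or cleared in one call (and test expectations match a slice, as the build-stage tests further down show). A toy, stdlib-only illustration of the new shape; the types here are invented for the example:

```go
package main

import (
	"context"
	"fmt"
)

// StateStore is an illustrative cut-down of the variadic interface shape.
type StateStore interface {
	AddFailedMessageID(ctx context.Context, ids ...string) error
	RemFailedMessageID(ctx context.Context, ids ...string) error
}

// memoryStore shows why the variadic form helps: one call per batch instead of
// one call per message ID.
type memoryStore struct {
	failed map[string]struct{}
}

func (m *memoryStore) AddFailedMessageID(_ context.Context, ids ...string) error {
	for _, id := range ids {
		m.failed[id] = struct{}{}
	}
	return nil
}

func (m *memoryStore) RemFailedMessageID(_ context.Context, ids ...string) error {
	for _, id := range ids {
		delete(m.failed, id)
	}
	return nil
}

func main() {
	store := &memoryStore{failed: map[string]struct{}{}}
	var s StateStore = store

	_ = s.AddFailedMessageID(context.Background(), "m1", "m2", "m3")
	_ = s.RemFailedMessageID(context.Background(), "m1", "m3") // batch removal in one call

	fmt.Println(len(store.failed), "message(s) still marked as failed") // 1
}
```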
@@ -777,21 +723,6 @@ func (mr *MockAPIClientMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.
     return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockAPIClient)(nil).GetMessage), arg0, arg1)
 }

-// GetMessageIDs mocks base method.
-func (m *MockAPIClient) GetMessageIDs(arg0 context.Context, arg1 string) ([]string, error) {
-    m.ctrl.T.Helper()
-    ret := m.ctrl.Call(m, "GetMessageIDs", arg0, arg1)
-    ret0, _ := ret[0].([]string)
-    ret1, _ := ret[1].(error)
-    return ret0, ret1
-}
-
-// GetMessageIDs indicates an expected call of GetMessageIDs.
-func (mr *MockAPIClientMockRecorder) GetMessageIDs(arg0, arg1 interface{}) *gomock.Call {
-    mr.mock.ctrl.T.Helper()
-    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageIDs", reflect.TypeOf((*MockAPIClient)(nil).GetMessageIDs), arg0, arg1)
-}
-
 // GetMessageMetadataPage mocks base method.
 func (m *MockAPIClient) GetMessageMetadataPage(arg0 context.Context, arg1, arg2 int, arg3 proton.MessageFilter) ([]proton.MessageMetadata, error) {
     m.ctrl.T.Helper()
@@ -830,6 +761,18 @@ func (m *MockReporter) EXPECT() *MockReporterMockRecorder {
     return m.recorder
 }

+// InitializeProgressCounter mocks base method.
+func (m *MockReporter) InitializeProgressCounter(arg0 context.Context, arg1, arg2 int64) {
+    m.ctrl.T.Helper()
+    m.ctrl.Call(m, "InitializeProgressCounter", arg0, arg1, arg2)
+}
+
+// InitializeProgressCounter indicates an expected call of InitializeProgressCounter.
+func (mr *MockReporterMockRecorder) InitializeProgressCounter(arg0, arg1, arg2 interface{}) *gomock.Call {
+    mr.mock.ctrl.T.Helper()
+    return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitializeProgressCounter", reflect.TypeOf((*MockReporter)(nil).InitializeProgressCounter), arg0, arg1, arg2)
+}
+
 // OnError mocks base method.
 func (m *MockReporter) OnError(arg0 context.Context, arg1 error) {
     m.ctrl.T.Helper()
@@ -48,7 +48,7 @@ func NewService(reporter reporter.Reporter,

     return &Service{
         limits: limits,
-        metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem),
+        metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem, panicHandler),
         downloadStage: NewDownloadStage(downloadCh, buildCh, 20, panicHandler),
         buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter),
         applyStage: NewApplyStage(applyCh),
@@ -22,6 +22,7 @@ import (
     "errors"

     "github.com/ProtonMail/gluon/async"
+    "github.com/ProtonMail/gluon/logging"
     "github.com/sirupsen/logrus"
 )

@@ -44,7 +45,15 @@ func NewApplyStage(input ApplyStageInput) *ApplyStage {
 }

 func (a *ApplyStage) Run(group *async.Group) {
-    group.Once(a.run)
+    group.Once(func(ctx context.Context) {
+        logging.DoAnnotated(
+            ctx,
+            func(ctx context.Context) {
+                a.run(ctx)
+            },
+            logging.Labels{"sync-stage": "apply"},
+        )
+    })
 }

 func (a *ApplyStage) run(ctx context.Context) {
@@ -24,6 +24,7 @@ import (
     "runtime"

     "github.com/ProtonMail/gluon/async"
+    "github.com/ProtonMail/gluon/logging"
     "github.com/ProtonMail/gluon/reporter"
     "github.com/ProtonMail/go-proton-api"
     "github.com/ProtonMail/gopenpgp/v2/crypto"
@@ -70,7 +71,15 @@ func NewBuildStage(
 }

 func (b *BuildStage) Run(group *async.Group) {
-    group.Once(b.run)
+    group.Once(func(ctx context.Context) {
+        logging.DoAnnotated(
+            ctx,
+            func(ctx context.Context) {
+                b.run(ctx)
+            },
+            logging.Labels{"sync-stage": "build"},
+        )
+    })
 }

 func (b *BuildStage) run(ctx context.Context) {
@@ -94,7 +103,21 @@ func (b *BuildStage) run(ctx context.Context) {
         err = req.job.messageBuilder.WithKeys(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
             chunks := chunkSyncBuilderBatch(req.batch, b.maxBuildMem)

-            for _, chunk := range chunks {
+            // This stage will split our existing job into many smaller bits. We need to update the Parent Job so
+            // that it correctly tracks the lifetime of extra jobs. Additionally, we also need to make sure
+            // that only the last chunk contains the metadata to clear the cache.
+            chunkedJobs := req.chunkDivide(chunks)
+
+            for idx, chunk := range chunks {
+                if chunkedJobs[idx].checkCancelled() {
+                    // Cancel all other chunks.
+                    for i := idx + 1; i < len(chunkedJobs); i++ {
+                        chunkedJobs[i].checkCancelled()
+                    }
+
+                    return nil
+                }
+
                 result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (BuildResult, error) {
                     defer async.HandlePanic(b.panicHandler)

@@ -135,21 +158,29 @@ func (b *BuildStage) run(ctx context.Context) {
                         return BuildResult{}, nil
                     }

-                    if err := req.job.state.RemFailedMessageID(req.getContext(), res.MessageID); err != nil {
-                        req.job.log.WithError(err).Error("Failed to remove failed message ID")
-                    }
-
                     return res, nil
                 })
                 if err != nil {
                     return err
                 }

+                success := xslices.Filter(result, func(t BuildResult) bool {
+                    return t.Update != nil
+                })
+
+                if len(success) > 0 {
+                    successIDs := xslices.Map(success, func(t BuildResult) string {
+                        return t.MessageID
+                    })
+
+                    if err := req.job.state.RemFailedMessageID(req.getContext(), successIDs...); err != nil {
+                        req.job.log.WithError(err).Error("Failed to remove failed message ID")
+                    }
+                }
+
                 b.output.Produce(ctx, ApplyRequest{
-                    childJob: req.childJob,
-                    messages: xslices.Filter(result, func(t BuildResult) bool {
-                        return t.Update != nil
-                    }),
+                    childJob: chunkedJobs[idx],
+                    messages: success,
                 })
             }

@@ -153,7 +153,7 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) {
     buildError := errors.New("it failed")

     tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(BuildResult{}, buildError)
-    tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
+    tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"}))
     mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
         "userID": "u",
         "messageID": "MSG",
@@ -204,7 +204,7 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin
     childJob := tj.job.newChildJob("f", 10)
     tj.job.end()

-    tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG"))
+    tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"}))
     mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
         "userID": "u",
         "messageID": "MSG",
@@ -112,8 +112,11 @@ func (d *DownloadStage) run(ctx context.Context) {
             }

             var attData [][]byte
-            if msg.NumAttachments > 0 {
-                attData = make([][]byte, msg.NumAttachments)
+
+            numAttachments := len(msg.Attachments)
+
+            if numAttachments > 0 {
+                attData = make([][]byte, numAttachments)
             }

             return proton.FullMessage{Message: msg, AttData: attData}, nil
@@ -139,7 +142,8 @@ func (d *DownloadStage) run(ctx context.Context) {
         attachmentIDs := make([]string, 0, len(result))

         for msgIdx, v := range result {
-            for attIdx := 0; attIdx < v.NumAttachments; attIdx++ {
+            numAttachments := len(v.Attachments)
+            for attIdx := 0; attIdx < numAttachments; attIdx++ {
                 attachmentIndices = append(attachmentIndices, attachmentMeta{
                     msgIdx: msgIdx,
                     attIdx: attIdx,
@@ -333,8 +333,7 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) {
     tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
     tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{
         MessageMetadata: proton.MessageMetadata{
             ID: "msg",
-            NumAttachments: 1,
         },
         Header: "",
         ParsedHeaders: nil,
@@ -436,7 +435,7 @@ func buildDownloadStageAttachments(msg *proton.FullMessage, index int) {
 func genDownloadStageAttachmentInfo(msg *proton.FullMessage, msgIdx int, count int) {
     msg.Attachments = make([]proton.Attachment, count)
     msg.AttData = make([][]byte, count)
-    msg.NumAttachments = count
     for i := 0; i < count; i++ {
         data := fmt.Sprintf("msg-%v-att-%v", msgIdx, i)
         msg.Attachments[i] = proton.Attachment{
@@ -19,11 +19,12 @@ package syncservice

 import (
     "context"
+    "errors"

     "github.com/ProtonMail/gluon/async"
+    "github.com/ProtonMail/gluon/logging"
     "github.com/ProtonMail/go-proton-api"
     "github.com/ProtonMail/proton-bridge/v3/internal/network"
-    "github.com/bradenaw/juniper/xslices"
     "github.com/sirupsen/logrus"
 )

@@ -37,80 +38,91 @@ type MetadataStage struct {
     output MetadataStageOutput
     input MetadataStageInput
     maxDownloadMem uint64
-    jobs []*metadataIterator
     log *logrus.Entry
+    panicHandler async.PanicHandler
 }

-func NewMetadataStage(input MetadataStageInput, output MetadataStageOutput, maxDownloadMem uint64) *MetadataStage {
-    return &MetadataStage{input: input, output: output, maxDownloadMem: maxDownloadMem, log: logrus.WithField("sync-stage", "metadata")}
+func NewMetadataStage(
+    input MetadataStageInput,
+    output MetadataStageOutput,
+    maxDownloadMem uint64,
+    panicHandler async.PanicHandler,
+) *MetadataStage {
+    return &MetadataStage{
+        input: input,
+        output: output,
+        maxDownloadMem: maxDownloadMem,
+        log: logrus.WithField("sync-stage", "metadata"),
+        panicHandler: panicHandler,
+    }
 }

 const MetadataPageSize = 150
 const MetadataMaxMessages = 250

-func (m MetadataStage) Run(group *async.Group) {
+func (m *MetadataStage) Run(group *async.Group) {
     group.Once(func(ctx context.Context) {
-        m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{})
+        logging.DoAnnotated(
+            ctx,
+            func(ctx context.Context) {
+                m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{})
+            },
+            logging.Labels{"sync-stage": "metadata"},
+        )
     })
 }

-func (m MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) {
+func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) {
     defer m.output.Close()

+    group := async.NewGroup(ctx, m.panicHandler)
+    defer group.CancelAndWait()

     for {
-        if ctx.Err() != nil {
-            return
-        }
+        job, err := m.input.Consume(ctx)

-        // Check if new job has been submitted
-        job, ok, err := m.input.TryConsume(ctx)
         if err != nil {
-            m.log.WithError(err).Error("Error trying to retrieve more work")
+            if !(errors.Is(err, context.Canceled) || errors.Is(err, ErrNoMoreInput)) {
+                m.log.WithError(err).Error("Error trying to retrieve more work")
+            }
             return
         }
-        if ok {
-            job.begin()
-            state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
-            if err != nil {
-                job.onError(err)
-                continue
-            }
-            m.jobs = append(m.jobs, state)
-        }
+        job.begin()
+        state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
+        if err != nil {
+            job.onError(err)
+            continue
+        }

-        // Iterate over all jobs and produce work.
-        for i := 0; i < len(m.jobs); {
-            job := m.jobs[i]
-
-            // If the job's context has been cancelled, remove from the list.
-            if job.stage.ctx.Err() != nil {
-                m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
-                job.stage.end()
-                continue
-            }
-            // Check for more work.
-            output, hasMore, err := job.Next(m.maxDownloadMem, metadataPageSize, maxMessages)
-            if err != nil {
-                job.stage.onError(err)
-                m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
-                continue
-            }
-
-            // If there is actually more work, push it down the pipeline.
-            if len(output.ids) != 0 {
-                m.output.Produce(ctx, output)
-            }
-
-            // If this job has no more work left, signal completion.
-            if !hasMore {
-                m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
-                job.stage.end()
-                continue
-            }
-
-            i++
-        }
+        group.Once(func(ctx context.Context) {
+            for {
+                if state.stage.ctx.Err() != nil {
+                    state.stage.end()
+                    return
+                }
+
+                // Check for more work.
+                output, hasMore, err := state.Next(m.maxDownloadMem, metadataPageSize, maxMessages)
+                if err != nil {
+                    state.stage.onError(err)
+                    return
+                }
+
+                // If there is actually more work, push it down the pipeline.
+                if len(output.ids) != 0 {
+                    state.stage.metadataFetched += int64(len(output.ids))
+                    job.log.Debugf("Metada collected: %v/%v", state.stage.metadataFetched, state.stage.totalMessageCount)
+
+                    m.output.Produce(ctx, output)
+                }
+
+                // If this job has no more work left, signal completion.
+                if !hasMore {
+                    state.stage.end()
+                    return
+                }
+            }
+        })
     }
 }

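The metadata-stage rewrite above stops polling a slice of jobs with TryConsume and instead blocks on Consume, starting one paginating goroutine per job; shutting the stage down cancels and waits for all of them (the real code uses gluon's async.Group with a panic handler). A stdlib-only sketch of that worker-per-job shape, with invented names:

```go
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// runPerJob blocks on the next job and starts one goroutine per job; returning
// waits for every per-job goroutine, mirroring the cancel-and-wait semantics
// described above (context + WaitGroup instead of async.Group).
func runPerJob(ctx context.Context, jobs <-chan int, handle func(context.Context, int)) {
	var wg sync.WaitGroup
	defer wg.Wait()

	for {
		select {
		case <-ctx.Done():
			return
		case job, ok := <-jobs:
			if !ok {
				return // no more input
			}
			wg.Add(1)
			go func() {
				defer wg.Done()
				handle(ctx, job) // each job paginates its metadata independently
			}()
		}
	}
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	jobs := make(chan int)
	go func() {
		for i := 1; i <= 3; i++ {
			jobs <- i
		}
		close(jobs)
	}()

	runPerJob(ctx, jobs, func(_ context.Context, job int) {
		for page := 0; page < 2; page++ {
			fmt.Printf("job %d: fetched metadata page %d\n", job, page)
			time.Sleep(10 * time.Millisecond)
		}
	})
}
```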
@@ -48,7 +48,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) {
     output := NewChannelConsumerProducer[DownloadRequest]()

     ctx, cancel := context.WithCancel(context.Background())
-    metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
+    metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})

     numMessages := 50
     messageSize := 100
@@ -86,7 +86,7 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
     output := NewChannelConsumerProducer[DownloadRequest]()

     ctx, cancel := context.WithCancel(context.Background())
-    metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
+    metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})

     go func() {
         metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
@@ -135,7 +135,7 @@ func TestMetadataStage_RunInterleaved(t *testing.T) {
     output := NewChannelConsumerProducer[DownloadRequest]()

     ctx, cancel := context.WithCancel(context.Background())
-    metadata := NewMetadataStage(input, output, TestMaxDownloadMem)
+    metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})

     numMessages := 50
     messageSize := 100
@ -31,7 +31,6 @@ var ErrNoMoreInput = errors.New("no more input")
|
|||||||
|
|
||||||
type StageInputConsumer[T any] interface {
|
type StageInputConsumer[T any] interface {
|
||||||
Consume(ctx context.Context) (T, error)
|
Consume(ctx context.Context) (T, error)
|
||||||
TryConsume(ctx context.Context) (T, bool, error)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChannelConsumerProducer[T any] struct {
|
type ChannelConsumerProducer[T any] struct {
|
||||||
@@ -66,20 +65,3 @@ func (c ChannelConsumerProducer[T]) Consume(ctx context.Context) (T, error) {
         return t, nil
     }
 }
-
-func (c ChannelConsumerProducer[T]) TryConsume(ctx context.Context) (T, bool, error) {
-    select {
-    case <-ctx.Done():
-        var t T
-        return t, false, ctx.Err()
-    case t, ok := <-c.ch:
-        if !ok {
-            return t, false, ErrNoMoreInput
-        }
-
-        return t, true, nil
-    default:
-        var t T
-        return t, false, nil
-    }
-}