fix(GODT-2829): Sync Service fixes

Fix tracking of child jobs. The build stage splits the incoming work
even further, but this was not reflected in the wait group counter. This
also fixes an issue where the cache was cleared to late.

Add more debug info for analysis.

Refactor sync state interface in order to have persistent sync rate.
This commit is contained in:
Leander Beernaert
2023-08-25 15:01:03 +02:00
parent 78f7cbdc79
commit aa77a67a1c
16 changed files with 221 additions and 189 deletions

View File

@ -299,11 +299,11 @@ EventSubscriber,MessageEventHandler,LabelEventHandler,AddressEventHandler,Refres
> internal/events/mocks/mocks.go > internal/events/mocks/mocks.go
mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider,Telemetry \ mockgen --package mocks github.com/ProtonMail/proton-bridge/v3/internal/services/useridentity IdentityProvider,Telemetry \
> internal/services/useridentity/mocks/mocks.go > internal/services/useridentity/mocks/mocks.go
mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/sync" -package sync github.com/ProtonMail/proton-bridge/v3/internal/services/sync \ mockgen --self_package "github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice" -package syncservice github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice \
ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,\ ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,\
StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier \ StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier \
> tmp > tmp
mv tmp internal/services/sync/mocks_test.go mv tmp internal/services/syncservice/mocks_test.go
lint: gofiles lint-golang lint-license lint-dependencies lint-changelog lint-bug-report lint: gofiles lint-golang lint-license lint-dependencies lint-changelog lint-bug-report

View File

@ -29,7 +29,6 @@ type APIClient interface {
GetLabels(ctx context.Context, labelTypes ...proton.LabelType) ([]proton.Label, error) GetLabels(ctx context.Context, labelTypes ...proton.LabelType) ([]proton.Label, error)
GetMessage(ctx context.Context, messageID string) (proton.Message, error) GetMessage(ctx context.Context, messageID string) (proton.Message, error)
GetMessageMetadataPage(ctx context.Context, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error) GetMessageMetadataPage(ctx context.Context, page, pageSize int, filter proton.MessageFilter) ([]proton.MessageMetadata, error)
GetMessageIDs(ctx context.Context, afterID string) ([]string, error)
GetFullMessage(ctx context.Context, messageID string, scheduler proton.Scheduler, storageProvider proton.AttachmentAllocator) (proton.FullMessage, error) GetFullMessage(ctx context.Context, messageID string, scheduler proton.Scheduler, storageProvider proton.AttachmentAllocator) (proton.FullMessage, error)
GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error GetAttachmentInto(ctx context.Context, attachmentID string, reader io.ReaderFrom) error
GetAttachment(ctx context.Context, attachmentID string) ([]byte, error) GetAttachment(ctx context.Context, attachmentID string) ([]byte, error)

View File

@ -117,7 +117,11 @@ func (t *Handler) Execute(
} }
t.log.WithField("duration", time.Since(start)).Info("Finished user sync") t.log.WithField("duration", time.Since(start)).Info("Finished user sync")
t.syncFinishedCh <- err select {
case <-ctx.Done():
return
case t.syncFinishedCh <- err:
}
}) })
} }
@ -179,8 +183,12 @@ func (t *Handler) run(ctx context.Context,
if err := t.syncState.SetMessageCount(ctx, totalMessageCount); err != nil { if err := t.syncState.SetMessageCount(ctx, totalMessageCount); err != nil {
return fmt.Errorf("failed to store message count: %w", err) return fmt.Errorf("failed to store message count: %w", err)
} }
syncStatus.TotalMessageCount = totalMessageCount
} }
syncReporter.InitializeProgressCounter(ctx, syncStatus.NumSyncedMessages, syncStatus.TotalMessageCount)
if !syncStatus.HasMessages { if !syncStatus.HasMessages {
t.log.Info("Syncing messages") t.log.Info("Syncing messages")
@ -198,6 +206,11 @@ func (t *Handler) run(ctx context.Context,
t.log, t.log,
) )
stageContext.metadataFetched = syncStatus.NumSyncedMessages
stageContext.totalMessageCount = syncStatus.TotalMessageCount
defer stageContext.Close()
t.regulator.Sync(ctx, stageContext) t.regulator.Sync(ctx, stageContext)
// Wait on reply // Wait on reply

View File

@ -57,6 +57,7 @@ func TestTask_NoStateAndSucceeds(t *testing.T) {
}) })
call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil) call2 := tt.syncState.EXPECT().SetHasLabels(gomock.Any(), gomock.Eq(true)).After(call1).Times(1).Return(nil)
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil) call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
call5 := tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) call5 := tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).After(call5).Times(1).DoAndReturn(func(_ context.Context) (Status, error) { tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).After(call5).Times(1).DoAndReturn(func(_ context.Context) (Status, error) {
@ -125,6 +126,7 @@ func TestTask_StateHasLabels(t *testing.T) {
call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil) call3 := tt.syncState.EXPECT().SetMessageCount(gomock.Any(), gomock.Eq(MessageTotal)).After(call2).Times(1).Return(nil)
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
} }
{ {
@ -170,6 +172,7 @@ func TestTask_StateHasLabelsAndMessageCount(t *testing.T) {
}) })
call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil) call4 := tt.syncState.EXPECT().SetLastMessageID(gomock.Any(), gomock.Eq(MessageID), gomock.Eq(MessageDelta)).After(call3).Times(1).Return(nil)
tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil) tt.syncState.EXPECT().SetHasMessages(gomock.Any(), gomock.Eq(true)).After(call4).Times(1).Return(nil)
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
} }
tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta)) tt.syncReporter.EXPECT().OnProgress(gomock.Any(), gomock.Eq(MessageDelta))
@ -219,6 +222,8 @@ func TestTask_RepeatsOnSyncFailure(t *testing.T) {
tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta) tt.addMessageSyncCompletedExpectation(MessageID, MessageDelta)
tt.syncReporter.EXPECT().InitializeProgressCounter(gomock.Any(), gomock.Any(), gomock.Eq(MessageTotal))
{ {
call0 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) { call0 := tt.syncState.EXPECT().GetSyncStatus(gomock.Any()).DoAndReturn(func(_ context.Context) (Status, error) {
return Status{ return Status{

View File

@ -28,8 +28,8 @@ import (
) )
type StateProvider interface { type StateProvider interface {
AddFailedMessageID(context.Context, string) error AddFailedMessageID(context.Context, ...string) error
RemFailedMessageID(context.Context, string) error RemFailedMessageID(context.Context, ...string) error
GetSyncStatus(context.Context) (Status, error) GetSyncStatus(context.Context) (Status, error)
ClearSyncStatus(context.Context) error ClearSyncStatus(context.Context) error
SetHasLabels(context.Context, bool) error SetHasLabels(context.Context, bool) error
@ -85,4 +85,5 @@ type Reporter interface {
OnFinished(ctx context.Context) OnFinished(ctx context.Context)
OnError(ctx context.Context, err error) OnError(ctx context.Context, err error)
OnProgress(ctx context.Context, delta int64) OnProgress(ctx context.Context, delta int64)
InitializeProgressCounter(ctx context.Context, current int64, total int64)
} }

View File

@ -24,6 +24,7 @@ import (
"sync" "sync"
"github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/go-proton-api"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -54,6 +55,9 @@ type Job struct {
panicHandler async.PanicHandler panicHandler async.PanicHandler
downloadCache *DownloadCache downloadCache *DownloadCache
metadataFetched int64
totalMessageCount int64
} }
func NewJob(ctx context.Context, func NewJob(ctx context.Context,
@ -178,6 +182,36 @@ func (s *childJob) userID() string {
return s.job.userID return s.job.userID
} }
func (s *childJob) chunkDivide(chunks [][]proton.FullMessage) []childJob {
numChunks := len(chunks)
if numChunks == 1 {
return []childJob{*s}
}
result := make([]childJob, numChunks)
for i := 0; i < numChunks-1; i++ {
result[i] = s.job.newChildJob(chunks[i][len(chunks[i])-1].ID, int64(len(chunks[i])))
collectIDs(&result[i], chunks[i])
}
result[numChunks-1] = *s
collectIDs(&result[numChunks-1], chunks[numChunks-1])
return result
}
func collectIDs(j *childJob, msgs []proton.FullMessage) {
j.cachedAttachmentIDs = make([]string, 0, len(msgs))
j.cachedMessageIDs = make([]string, 0, len(msgs))
for _, msg := range msgs {
j.cachedMessageIDs = append(j.cachedMessageIDs, msg.ID)
for _, attach := range msg.Attachments {
j.cachedAttachmentIDs = append(j.cachedAttachmentIDs, attach.ID)
}
}
}
func (s *childJob) onFinished(ctx context.Context) { func (s *childJob) onFinished(ctx context.Context) {
s.job.log.Infof("Child job finished") s.job.log.Infof("Child job finished")
s.job.onJobFinished(ctx, s.lastMessageID, s.messageCount) s.job.onJobFinished(ctx, s.lastMessageID, s.messageCount)

View File

@ -1,7 +1,7 @@
// Code generated by MockGen. DO NOT EDIT. // Code generated by MockGen. DO NOT EDIT.
// Source: github.com/ProtonMail/proton-bridge/v3/internal/services/sync (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier) // Source: github.com/ProtonMail/proton-bridge/v3/internal/services/syncservice (interfaces: ApplyStageInput,BuildStageInput,BuildStageOutput,DownloadStageInput,DownloadStageOutput,MetadataStageInput,MetadataStageOutput,StateProvider,Regulator,UpdateApplier,MessageBuilder,APIClient,Reporter,DownloadRateModifier)
// Package sync is a generated GoMock package. // Package syncservice is a generated GoMock package.
package syncservice package syncservice
import ( import (
@ -53,22 +53,6 @@ func (mr *MockApplyStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockApplyStageInput)(nil).Consume), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockApplyStageInput)(nil).Consume), arg0)
} }
// TryConsume mocks base method.
func (m *MockApplyStageInput) TryConsume(arg0 context.Context) (ApplyRequest, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(ApplyRequest)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockApplyStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockApplyStageInput)(nil).TryConsume), arg0)
}
// MockBuildStageInput is a mock of BuildStageInput interface. // MockBuildStageInput is a mock of BuildStageInput interface.
type MockBuildStageInput struct { type MockBuildStageInput struct {
ctrl *gomock.Controller ctrl *gomock.Controller
@ -107,22 +91,6 @@ func (mr *MockBuildStageInputMockRecorder) Consume(arg0 interface{}) *gomock.Cal
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockBuildStageInput)(nil).Consume), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockBuildStageInput)(nil).Consume), arg0)
} }
// TryConsume mocks base method.
func (m *MockBuildStageInput) TryConsume(arg0 context.Context) (BuildRequest, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(BuildRequest)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockBuildStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockBuildStageInput)(nil).TryConsume), arg0)
}
// MockBuildStageOutput is a mock of BuildStageOutput interface. // MockBuildStageOutput is a mock of BuildStageOutput interface.
type MockBuildStageOutput struct { type MockBuildStageOutput struct {
ctrl *gomock.Controller ctrl *gomock.Controller
@ -208,22 +176,6 @@ func (mr *MockDownloadStageInputMockRecorder) Consume(arg0 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockDownloadStageInput)(nil).Consume), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockDownloadStageInput)(nil).Consume), arg0)
} }
// TryConsume mocks base method.
func (m *MockDownloadStageInput) TryConsume(arg0 context.Context) (DownloadRequest, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(DownloadRequest)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockDownloadStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockDownloadStageInput)(nil).TryConsume), arg0)
}
// MockDownloadStageOutput is a mock of DownloadStageOutput interface. // MockDownloadStageOutput is a mock of DownloadStageOutput interface.
type MockDownloadStageOutput struct { type MockDownloadStageOutput struct {
ctrl *gomock.Controller ctrl *gomock.Controller
@ -309,22 +261,6 @@ func (mr *MockMetadataStageInputMockRecorder) Consume(arg0 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockMetadataStageInput)(nil).Consume), arg0) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Consume", reflect.TypeOf((*MockMetadataStageInput)(nil).Consume), arg0)
} }
// TryConsume mocks base method.
func (m *MockMetadataStageInput) TryConsume(arg0 context.Context) (*Job, bool, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "TryConsume", arg0)
ret0, _ := ret[0].(*Job)
ret1, _ := ret[1].(bool)
ret2, _ := ret[2].(error)
return ret0, ret1, ret2
}
// TryConsume indicates an expected call of TryConsume.
func (mr *MockMetadataStageInputMockRecorder) TryConsume(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "TryConsume", reflect.TypeOf((*MockMetadataStageInput)(nil).TryConsume), arg0)
}
// MockMetadataStageOutput is a mock of MetadataStageOutput interface. // MockMetadataStageOutput is a mock of MetadataStageOutput interface.
type MockMetadataStageOutput struct { type MockMetadataStageOutput struct {
ctrl *gomock.Controller ctrl *gomock.Controller
@ -396,17 +332,22 @@ func (m *MockStateProvider) EXPECT() *MockStateProviderMockRecorder {
} }
// AddFailedMessageID mocks base method. // AddFailedMessageID mocks base method.
func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 string) error { func (m *MockStateProvider) AddFailedMessageID(arg0 context.Context, arg1 ...string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "AddFailedMessageID", arg0, arg1) varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "AddFailedMessageID", varargs...)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// AddFailedMessageID indicates an expected call of AddFailedMessageID. // AddFailedMessageID indicates an expected call of AddFailedMessageID.
func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0, arg1 interface{}) *gomock.Call { func (mr *MockStateProviderMockRecorder) AddFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), arg0, arg1) varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).AddFailedMessageID), varargs...)
} }
// ClearSyncStatus mocks base method. // ClearSyncStatus mocks base method.
@ -439,17 +380,22 @@ func (mr *MockStateProviderMockRecorder) GetSyncStatus(arg0 interface{}) *gomock
} }
// RemFailedMessageID mocks base method. // RemFailedMessageID mocks base method.
func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 string) error { func (m *MockStateProvider) RemFailedMessageID(arg0 context.Context, arg1 ...string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RemFailedMessageID", arg0, arg1) varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "RemFailedMessageID", varargs...)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// RemFailedMessageID indicates an expected call of RemFailedMessageID. // RemFailedMessageID indicates an expected call of RemFailedMessageID.
func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0, arg1 interface{}) *gomock.Call { func (mr *MockStateProviderMockRecorder) RemFailedMessageID(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), arg0, arg1) varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemFailedMessageID", reflect.TypeOf((*MockStateProvider)(nil).RemFailedMessageID), varargs...)
} }
// SetHasLabels mocks base method. // SetHasLabels mocks base method.
@ -777,21 +723,6 @@ func (mr *MockAPIClientMockRecorder) GetMessage(arg0, arg1 interface{}) *gomock.
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockAPIClient)(nil).GetMessage), arg0, arg1) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessage", reflect.TypeOf((*MockAPIClient)(nil).GetMessage), arg0, arg1)
} }
// GetMessageIDs mocks base method.
func (m *MockAPIClient) GetMessageIDs(arg0 context.Context, arg1 string) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetMessageIDs", arg0, arg1)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetMessageIDs indicates an expected call of GetMessageIDs.
func (mr *MockAPIClientMockRecorder) GetMessageIDs(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetMessageIDs", reflect.TypeOf((*MockAPIClient)(nil).GetMessageIDs), arg0, arg1)
}
// GetMessageMetadataPage mocks base method. // GetMessageMetadataPage mocks base method.
func (m *MockAPIClient) GetMessageMetadataPage(arg0 context.Context, arg1, arg2 int, arg3 proton.MessageFilter) ([]proton.MessageMetadata, error) { func (m *MockAPIClient) GetMessageMetadataPage(arg0 context.Context, arg1, arg2 int, arg3 proton.MessageFilter) ([]proton.MessageMetadata, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -830,6 +761,18 @@ func (m *MockReporter) EXPECT() *MockReporterMockRecorder {
return m.recorder return m.recorder
} }
// InitializeProgressCounter mocks base method.
func (m *MockReporter) InitializeProgressCounter(arg0 context.Context, arg1, arg2 int64) {
m.ctrl.T.Helper()
m.ctrl.Call(m, "InitializeProgressCounter", arg0, arg1, arg2)
}
// InitializeProgressCounter indicates an expected call of InitializeProgressCounter.
func (mr *MockReporterMockRecorder) InitializeProgressCounter(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InitializeProgressCounter", reflect.TypeOf((*MockReporter)(nil).InitializeProgressCounter), arg0, arg1, arg2)
}
// OnError mocks base method. // OnError mocks base method.
func (m *MockReporter) OnError(arg0 context.Context, arg1 error) { func (m *MockReporter) OnError(arg0 context.Context, arg1 error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -48,7 +48,7 @@ func NewService(reporter reporter.Reporter,
return &Service{ return &Service{
limits: limits, limits: limits,
metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem), metadataStage: NewMetadataStage(metaCh, downloadCh, limits.DownloadRequestMem, panicHandler),
downloadStage: NewDownloadStage(downloadCh, buildCh, 20, panicHandler), downloadStage: NewDownloadStage(downloadCh, buildCh, 20, panicHandler),
buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter), buildStage: NewBuildStage(buildCh, applyCh, limits.MessageBuildMem, panicHandler, reporter),
applyStage: NewApplyStage(applyCh), applyStage: NewApplyStage(applyCh),

View File

@ -22,6 +22,7 @@ import (
"errors" "errors"
"github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/logging"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -44,7 +45,15 @@ func NewApplyStage(input ApplyStageInput) *ApplyStage {
} }
func (a *ApplyStage) Run(group *async.Group) { func (a *ApplyStage) Run(group *async.Group) {
group.Once(a.run) group.Once(func(ctx context.Context) {
logging.DoAnnotated(
ctx,
func(ctx context.Context) {
a.run(ctx)
},
logging.Labels{"sync-stage": "apply"},
)
})
} }
func (a *ApplyStage) run(ctx context.Context) { func (a *ApplyStage) run(ctx context.Context) {

View File

@ -24,6 +24,7 @@ import (
"runtime" "runtime"
"github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/logging"
"github.com/ProtonMail/gluon/reporter" "github.com/ProtonMail/gluon/reporter"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/gopenpgp/v2/crypto" "github.com/ProtonMail/gopenpgp/v2/crypto"
@ -70,7 +71,15 @@ func NewBuildStage(
} }
func (b *BuildStage) Run(group *async.Group) { func (b *BuildStage) Run(group *async.Group) {
group.Once(b.run) group.Once(func(ctx context.Context) {
logging.DoAnnotated(
ctx,
func(ctx context.Context) {
b.run(ctx)
},
logging.Labels{"sync-stage": "build"},
)
})
} }
func (b *BuildStage) run(ctx context.Context) { func (b *BuildStage) run(ctx context.Context) {
@ -94,7 +103,21 @@ func (b *BuildStage) run(ctx context.Context) {
err = req.job.messageBuilder.WithKeys(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error { err = req.job.messageBuilder.WithKeys(func(_ *crypto.KeyRing, addrKRs map[string]*crypto.KeyRing) error {
chunks := chunkSyncBuilderBatch(req.batch, b.maxBuildMem) chunks := chunkSyncBuilderBatch(req.batch, b.maxBuildMem)
for _, chunk := range chunks { // This stage will split our existing job into many smaller bits. We need to update the Parent Job so
// that it correctly tracks the lifetime of extra jobs. Additionally, we also need to make sure
// that only the last chunk contains the metadata to clear the cache.
chunkedJobs := req.chunkDivide(chunks)
for idx, chunk := range chunks {
if chunkedJobs[idx].checkCancelled() {
// Cancel all other chunks.
for i := idx + 1; i < len(chunkedJobs); i++ {
chunkedJobs[i].checkCancelled()
}
return nil
}
result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (BuildResult, error) { result, err := parallel.MapContext(ctx, maxMessagesInParallel, chunk, func(ctx context.Context, msg proton.FullMessage) (BuildResult, error) {
defer async.HandlePanic(b.panicHandler) defer async.HandlePanic(b.panicHandler)
@ -135,21 +158,29 @@ func (b *BuildStage) run(ctx context.Context) {
return BuildResult{}, nil return BuildResult{}, nil
} }
if err := req.job.state.RemFailedMessageID(req.getContext(), res.MessageID); err != nil {
req.job.log.WithError(err).Error("Failed to remove failed message ID")
}
return res, nil return res, nil
}) })
if err != nil { if err != nil {
return err return err
} }
success := xslices.Filter(result, func(t BuildResult) bool {
return t.Update != nil
})
if len(success) > 0 {
successIDs := xslices.Map(success, func(t BuildResult) string {
return t.MessageID
})
if err := req.job.state.RemFailedMessageID(req.getContext(), successIDs...); err != nil {
req.job.log.WithError(err).Error("Failed to remove failed message ID")
}
}
b.output.Produce(ctx, ApplyRequest{ b.output.Produce(ctx, ApplyRequest{
childJob: req.childJob, childJob: chunkedJobs[idx],
messages: xslices.Filter(result, func(t BuildResult) bool { messages: success,
return t.Update != nil
}),
}) })
} }

View File

@ -153,7 +153,7 @@ func TestBuildStage_BuildFailureIsReportedButDoesNotCancelJob(t *testing.T) {
buildError := errors.New("it failed") buildError := errors.New("it failed")
tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(BuildResult{}, buildError) tj.messageBuilder.EXPECT().BuildMessage(gomock.Eq(labels), gomock.Eq(msg), gomock.Any(), gomock.Any()).Return(BuildResult{}, buildError)
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG")) tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"}))
mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{ mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
"userID": "u", "userID": "u",
"messageID": "MSG", "messageID": "MSG",
@ -204,7 +204,7 @@ func TestBuildStage_FailedToLocateKeyRingIsReportedButDoesNotFailBuild(t *testin
childJob := tj.job.newChildJob("f", 10) childJob := tj.job.newChildJob("f", 10)
tj.job.end() tj.job.end()
tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq("MSG")) tj.state.EXPECT().AddFailedMessageID(gomock.Any(), gomock.Eq([]string{"MSG"}))
mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{ mockReporter.EXPECT().ReportMessageWithContext(gomock.Any(), gomock.Eq(reporter.Context{
"userID": "u", "userID": "u",
"messageID": "MSG", "messageID": "MSG",

View File

@ -112,8 +112,11 @@ func (d *DownloadStage) run(ctx context.Context) {
} }
var attData [][]byte var attData [][]byte
if msg.NumAttachments > 0 {
attData = make([][]byte, msg.NumAttachments) numAttachments := len(msg.Attachments)
if numAttachments > 0 {
attData = make([][]byte, numAttachments)
} }
return proton.FullMessage{Message: msg, AttData: attData}, nil return proton.FullMessage{Message: msg, AttData: attData}, nil
@ -139,7 +142,8 @@ func (d *DownloadStage) run(ctx context.Context) {
attachmentIDs := make([]string, 0, len(result)) attachmentIDs := make([]string, 0, len(result))
for msgIdx, v := range result { for msgIdx, v := range result {
for attIdx := 0; attIdx < v.NumAttachments; attIdx++ { numAttachments := len(v.Attachments)
for attIdx := 0; attIdx < numAttachments; attIdx++ {
attachmentIndices = append(attachmentIndices, attachmentMeta{ attachmentIndices = append(attachmentIndices, attachmentMeta{
msgIdx: msgIdx, msgIdx: msgIdx,
attIdx: attIdx, attIdx: attIdx,

View File

@ -333,8 +333,7 @@ func TestDownloadStage_JobAbortsOnAttachmentDownloadError(t *testing.T) {
tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{}) tj := newTestJob(jobCtx, mockCtrl, "", map[string]proton.Label{})
tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{ tj.client.EXPECT().GetMessage(gomock.Any(), gomock.Any()).Return(proton.Message{
MessageMetadata: proton.MessageMetadata{ MessageMetadata: proton.MessageMetadata{
ID: "msg", ID: "msg",
NumAttachments: 1,
}, },
Header: "", Header: "",
ParsedHeaders: nil, ParsedHeaders: nil,
@ -436,7 +435,7 @@ func buildDownloadStageAttachments(msg *proton.FullMessage, index int) {
func genDownloadStageAttachmentInfo(msg *proton.FullMessage, msgIdx int, count int) { func genDownloadStageAttachmentInfo(msg *proton.FullMessage, msgIdx int, count int) {
msg.Attachments = make([]proton.Attachment, count) msg.Attachments = make([]proton.Attachment, count)
msg.AttData = make([][]byte, count) msg.AttData = make([][]byte, count)
msg.NumAttachments = count
for i := 0; i < count; i++ { for i := 0; i < count; i++ {
data := fmt.Sprintf("msg-%v-att-%v", msgIdx, i) data := fmt.Sprintf("msg-%v-att-%v", msgIdx, i)
msg.Attachments[i] = proton.Attachment{ msg.Attachments[i] = proton.Attachment{

View File

@ -19,11 +19,12 @@ package syncservice
import ( import (
"context" "context"
"errors"
"github.com/ProtonMail/gluon/async" "github.com/ProtonMail/gluon/async"
"github.com/ProtonMail/gluon/logging"
"github.com/ProtonMail/go-proton-api" "github.com/ProtonMail/go-proton-api"
"github.com/ProtonMail/proton-bridge/v3/internal/network" "github.com/ProtonMail/proton-bridge/v3/internal/network"
"github.com/bradenaw/juniper/xslices"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
@ -37,80 +38,91 @@ type MetadataStage struct {
output MetadataStageOutput output MetadataStageOutput
input MetadataStageInput input MetadataStageInput
maxDownloadMem uint64 maxDownloadMem uint64
jobs []*metadataIterator
log *logrus.Entry log *logrus.Entry
panicHandler async.PanicHandler
} }
func NewMetadataStage(input MetadataStageInput, output MetadataStageOutput, maxDownloadMem uint64) *MetadataStage { func NewMetadataStage(
return &MetadataStage{input: input, output: output, maxDownloadMem: maxDownloadMem, log: logrus.WithField("sync-stage", "metadata")} input MetadataStageInput,
output MetadataStageOutput,
maxDownloadMem uint64,
panicHandler async.PanicHandler,
) *MetadataStage {
return &MetadataStage{
input: input,
output: output,
maxDownloadMem: maxDownloadMem,
log: logrus.WithField("sync-stage", "metadata"),
panicHandler: panicHandler,
}
} }
const MetadataPageSize = 150 const MetadataPageSize = 150
const MetadataMaxMessages = 250 const MetadataMaxMessages = 250
func (m MetadataStage) Run(group *async.Group) { func (m *MetadataStage) Run(group *async.Group) {
group.Once(func(ctx context.Context) { group.Once(func(ctx context.Context) {
m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{}) logging.DoAnnotated(
ctx,
func(ctx context.Context) {
m.run(ctx, MetadataPageSize, MetadataMaxMessages, &network.ExpCoolDown{})
},
logging.Labels{"sync-stage": "metadata"},
)
}) })
} }
func (m MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) { func (m *MetadataStage) run(ctx context.Context, metadataPageSize int, maxMessages int, coolDown network.CoolDownProvider) {
defer m.output.Close() defer m.output.Close()
group := async.NewGroup(ctx, m.panicHandler)
defer group.CancelAndWait()
for { for {
if ctx.Err() != nil { job, err := m.input.Consume(ctx)
return
}
// Check if new job has been submitted
job, ok, err := m.input.TryConsume(ctx)
if err != nil { if err != nil {
m.log.WithError(err).Error("Error trying to retrieve more work") if !(errors.Is(err, context.Canceled) || errors.Is(err, ErrNoMoreInput)) {
m.log.WithError(err).Error("Error trying to retrieve more work")
}
return return
} }
if ok {
job.begin() job.begin()
state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown) state, err := newMetadataIterator(job.ctx, job, metadataPageSize, coolDown)
if err != nil { if err != nil {
job.onError(err) job.onError(err)
continue continue
}
m.jobs = append(m.jobs, state)
} }
// Iterate over all jobs and produce work. group.Once(func(ctx context.Context) {
for i := 0; i < len(m.jobs); { for {
job := m.jobs[i] if state.stage.ctx.Err() != nil {
state.stage.end()
return
}
// If the job's context has been cancelled, remove from the list. // Check for more work.
if job.stage.ctx.Err() != nil { output, hasMore, err := state.Next(m.maxDownloadMem, metadataPageSize, maxMessages)
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1) if err != nil {
job.stage.end() state.stage.onError(err)
continue return
}
// If there is actually more work, push it down the pipeline.
if len(output.ids) != 0 {
state.stage.metadataFetched += int64(len(output.ids))
job.log.Debugf("Metada collected: %v/%v", state.stage.metadataFetched, state.stage.totalMessageCount)
m.output.Produce(ctx, output)
}
// If this job has no more work left, signal completion.
if !hasMore {
state.stage.end()
return
}
} }
})
// Check for more work.
output, hasMore, err := job.Next(m.maxDownloadMem, metadataPageSize, maxMessages)
if err != nil {
job.stage.onError(err)
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
continue
}
// If there is actually more work, push it down the pipeline.
if len(output.ids) != 0 {
m.output.Produce(ctx, output)
}
// If this job has no more work left, signal completion.
if !hasMore {
m.jobs = xslices.RemoveUnordered(m.jobs, i, 1)
job.stage.end()
continue
}
i++
}
} }
} }

View File

@ -48,7 +48,7 @@ func TestMetadataStage_RunFinishesWith429(t *testing.T) {
output := NewChannelConsumerProducer[DownloadRequest]() output := NewChannelConsumerProducer[DownloadRequest]()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
metadata := NewMetadataStage(input, output, TestMaxDownloadMem) metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})
numMessages := 50 numMessages := 50
messageSize := 100 messageSize := 100
@ -86,7 +86,7 @@ func TestMetadataStage_JobCorrectlyFinishesAfterCancel(t *testing.T) {
output := NewChannelConsumerProducer[DownloadRequest]() output := NewChannelConsumerProducer[DownloadRequest]()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
metadata := NewMetadataStage(input, output, TestMaxDownloadMem) metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})
go func() { go func() {
metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{}) metadata.run(ctx, TestMetadataPageSize, TestMaxMessages, &network.NoCoolDown{})
@ -135,7 +135,7 @@ func TestMetadataStage_RunInterleaved(t *testing.T) {
output := NewChannelConsumerProducer[DownloadRequest]() output := NewChannelConsumerProducer[DownloadRequest]()
ctx, cancel := context.WithCancel(context.Background()) ctx, cancel := context.WithCancel(context.Background())
metadata := NewMetadataStage(input, output, TestMaxDownloadMem) metadata := NewMetadataStage(input, output, TestMaxDownloadMem, &async.NoopPanicHandler{})
numMessages := 50 numMessages := 50
messageSize := 100 messageSize := 100

View File

@ -31,7 +31,6 @@ var ErrNoMoreInput = errors.New("no more input")
type StageInputConsumer[T any] interface { type StageInputConsumer[T any] interface {
Consume(ctx context.Context) (T, error) Consume(ctx context.Context) (T, error)
TryConsume(ctx context.Context) (T, bool, error)
} }
type ChannelConsumerProducer[T any] struct { type ChannelConsumerProducer[T any] struct {
@ -66,20 +65,3 @@ func (c ChannelConsumerProducer[T]) Consume(ctx context.Context) (T, error) {
return t, nil return t, nil
} }
} }
func (c ChannelConsumerProducer[T]) TryConsume(ctx context.Context) (T, bool, error) {
select {
case <-ctx.Done():
var t T
return t, false, ctx.Err()
case t, ok := <-c.ch:
if !ok {
return t, false, ErrNoMoreInput
}
return t, true, nil
default:
var t T
return t, false, nil
}
}