Files
proton-bridge/internal/store/sync_test.go
2022-05-31 15:54:39 +02:00

529 lines
12 KiB
Go

// Copyright (c) 2022 Proton AG
//
// This file is part of Proton Mail Bridge.
//
// Proton Mail Bridge is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// Proton Mail Bridge is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with Proton Mail Bridge. If not, see <https://www.gnu.org/licenses/>.
package store
import (
"context"
"sort"
"strconv"
"sync"
"testing"
"github.com/ProtonMail/proton-bridge/v2/pkg/pmapi"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
type mockLister struct {
err error
messageIDs []string
}
func (m *mockLister) ListMessages(_ context.Context, filter *pmapi.MessagesFilter) (msgs []*pmapi.Message, total int, err error) {
if m.err != nil {
return nil, 0, m.err
}
skipByID := true
skipByPaging := filter.PageSize * filter.Page
for idx := 0; idx < len(m.messageIDs); idx++ {
var messageID string
if !*filter.Desc {
messageID = m.messageIDs[idx]
if filter.BeginID == "" || messageID == filter.BeginID {
skipByID = false
}
} else {
messageID = m.messageIDs[len(m.messageIDs)-1-idx]
if filter.EndID == "" || messageID == filter.EndID {
skipByID = false
}
}
if skipByID {
continue
}
skipByPaging--
if skipByPaging > 0 {
continue
}
msgs = append(msgs, &pmapi.Message{
ID: messageID,
})
if len(msgs) == filter.PageSize || len(msgs) == filter.Limit {
break
}
if !*filter.Desc {
if messageID == filter.EndID {
break
}
} else {
if messageID == filter.BeginID {
break
}
}
}
return msgs, len(m.messageIDs), nil
}
type mockStoreSynchronizer struct {
locker sync.Locker
allMessageIDs []string
errCreateOrUpdateMessagesEvent error
createdMessageIDsByBatch [][]string
}
func newSyncer() *mockStoreSynchronizer {
return &mockStoreSynchronizer{
locker: &sync.Mutex{},
}
}
func (m *mockStoreSynchronizer) getAllMessageIDs() ([]string, error) {
m.locker.Lock()
defer m.locker.Unlock()
return m.allMessageIDs, nil
}
func (m *mockStoreSynchronizer) createOrUpdateMessagesEvent(messages []*pmapi.Message) error {
m.locker.Lock()
defer m.locker.Unlock()
if m.errCreateOrUpdateMessagesEvent != nil {
return m.errCreateOrUpdateMessagesEvent
}
createdMessageIDs := []string{}
for _, message := range messages {
createdMessageIDs = append(createdMessageIDs, message.ID)
}
m.createdMessageIDsByBatch = append(m.createdMessageIDsByBatch, createdMessageIDs)
return nil
}
func (m *mockStoreSynchronizer) deleteMessagesEvent([]string) error {
m.locker.Lock()
defer m.locker.Unlock()
return nil
}
func (m *mockStoreSynchronizer) saveSyncState(finishTime int64, idRanges []*syncIDRange, idsToBeDeleted []string) {
m.locker.Lock()
defer m.locker.Unlock()
}
func newTestSyncState(store storeSynchronizer, splitIDs ...string) *syncState {
syncState := newSyncState(store, 0, []*syncIDRange{}, []string{})
syncState.initIDRanges()
for _, splitID := range splitIDs {
syncState.addIDRange(splitID)
}
return syncState
}
func generateIDs(start, stop int) []string {
ids := []string{}
for x := start; x <= stop; x++ {
ids = append(ids, strconv.Itoa(x))
}
return ids
}
func generateIDsR(start, stop int) []string {
ids := []string{}
for x := start; x >= stop; x-- {
ids = append(ids, strconv.Itoa(x))
}
return ids
}
// Tests
func TestSyncAllMail(t *testing.T) { //nolint:funlen
m, clear := initMocks(t)
defer clear()
numberOfMessages := 10000
api := &mockLister{
messageIDs: generateIDs(1, numberOfMessages),
}
tests := []struct {
name string
idRanges []*syncIDRange
idsToBeDeleted []string
wantUpdatedIDs []string
wantNotUpdatedIDs []string
}{
{
"full sync",
[]*syncIDRange{},
[]string{},
api.messageIDs,
[]string{},
},
{
"continue with interrupted sync",
[]*syncIDRange{
{StartID: "2000", StopID: "2100"},
{StartID: "4000", StopID: "4200"},
{StartID: "9500", StopID: ""},
},
mergeArrays(generateIDs(2000, 2100), generateIDs(4000, 4200), generateIDs(9500, 10010)),
mergeArrays(generateIDs(2000, 2100), generateIDs(4000, 4200), generateIDs(9500, 10000)),
mergeArrays(generateIDs(1, 1999), generateIDs(2101, 3999), generateIDs(4201, 9459)),
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
store := newSyncer()
store.allMessageIDs = generateIDs(1, numberOfMessages+10)
syncState := newSyncState(store, 0, tc.idRanges, tc.idsToBeDeleted)
err := syncAllMail(m.panicHandler, store, api, syncState)
require.Nil(t, err)
// Check all messages were created or updated.
updateMessageIDsMap := map[string]bool{}
for _, messageIDs := range store.createdMessageIDsByBatch {
for _, messageID := range messageIDs {
updateMessageIDsMap[messageID] = true
}
}
for _, messageID := range tc.wantUpdatedIDs {
assert.True(t, updateMessageIDsMap[messageID], "Message %s was not created/updated, but should", messageID)
}
for _, messageID := range tc.wantNotUpdatedIDs {
assert.False(t, updateMessageIDsMap[messageID], "Message %s was created/updated, but shouldn't", messageID)
}
// Check all messages were deleted.
idsToBeDeleted := syncState.getIDsToBeDeleted()
sort.Strings(idsToBeDeleted)
assert.Equal(t, generateIDs(numberOfMessages+1, numberOfMessages+10), idsToBeDeleted)
})
}
}
func mergeArrays(arrays ...[]string) []string {
result := []string{}
for _, array := range arrays {
result = append(result, array...)
}
return result
}
func TestSyncAllMail_FailedListing(t *testing.T) {
m, clear := initMocks(t)
defer clear()
numberOfMessages := 10000
store := newSyncer()
store.allMessageIDs = generateIDs(1, numberOfMessages+10)
api := &mockLister{
err: errors.New("error"),
messageIDs: generateIDs(1, numberOfMessages),
}
syncState := newTestSyncState(store)
err := syncAllMail(m.panicHandler, store, api, syncState)
require.EqualError(t, err, "failed to sync group: failed to list messages: error")
}
func TestSyncAllMail_FailedCreateOrUpdateMessage(t *testing.T) {
m, clear := initMocks(t)
defer clear()
numberOfMessages := 10000
store := newSyncer()
store.errCreateOrUpdateMessagesEvent = errors.New("error")
store.allMessageIDs = generateIDs(1, numberOfMessages+10)
api := &mockLister{
messageIDs: generateIDs(1, numberOfMessages),
}
syncState := newTestSyncState(store)
err := syncAllMail(m.panicHandler, store, api, syncState)
require.EqualError(t, err, "failed to sync group: failed to create or update messages: error")
}
func TestFindIDRanges(t *testing.T) { //nolint:funlen
store := newSyncer()
syncState := newTestSyncState(store)
tests := []struct {
name string
messageIDs []string
wantBatches [][]string
}{
{
"1k messages - 1 batch",
generateIDs(1, 1000),
[][]string{
{"", ""},
},
},
{
"1k messages not starting at 1",
generateIDs(1000, 2000),
[][]string{
{"", ""},
},
},
{
"2k messages - 2 batches",
generateIDs(1, 2000),
[][]string{
{"", "1050"},
{"1050", ""},
},
},
{
"4k messages - 3 batches",
generateIDs(1, 4000),
[][]string{
{"", "1350"},
{"1350", "2700"},
{"2700", ""},
},
},
{
"5k messages - 4 batches",
generateIDs(1, 5000),
[][]string{
{"", "1350"},
{"1350", "2700"},
{"2700", "4050"},
{"4050", ""},
},
},
{
"10k messages - 5 batches",
generateIDs(1, 10000),
[][]string{
{"", "2100"},
{"2100", "4200"},
{"4200", "6300"},
{"6300", "8400"},
{"8400", ""},
},
},
{
"150k messages - 5 batches",
generateIDs(1, 150000),
[][]string{
{"", "30000"},
{"30000", "60000"},
{"60000", "90000"},
{"90000", "120000"},
{"120000", ""},
},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
api := &mockLister{
messageIDs: tc.messageIDs,
}
err := findIDRanges(pmapi.AllMailLabel, api, syncState)
require.Nil(t, err)
require.Equal(t, len(tc.wantBatches), len(syncState.idRanges))
for idx, idRange := range syncState.idRanges {
want := tc.wantBatches[idx]
assert.Equal(t, want[0], idRange.StartID, "Start ID for IDs range %d does not match", idx)
assert.Equal(t, want[1], idRange.StopID, "Stop ID for IDs range %d does not match", idx)
}
})
}
}
func TestFindIDRanges_FailedListing(t *testing.T) {
store := newSyncer()
api := &mockLister{
err: errors.New("error"),
}
syncState := newTestSyncState(store)
err := findIDRanges(pmapi.AllMailLabel, api, syncState)
require.EqualError(t, err, "failed to get first ID and count: failed to list messages: error")
}
func TestGetSplitIDAndCount(t *testing.T) { //nolint:funlen
tests := []struct {
name string
err error
messageIDs []string
page int
wantID string
wantTotal int
wantErr string
}{
{
"1000 messages, first page",
nil,
generateIDs(1, 1000),
0,
"1",
1000,
"",
},
{
"1000 messages, 5th page",
nil,
generateIDs(1, 1000),
4,
"600",
1000,
"",
},
{
"one message, first page",
nil,
[]string{"1"},
0,
"1",
1,
"",
},
{
"no message, first page",
nil,
[]string{},
0,
"",
0,
"",
},
{
"listing error",
errors.New("error"),
generateIDs(1, 1000),
0,
"",
0,
"failed to list messages: error",
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
api := &mockLister{
err: tc.err,
messageIDs: tc.messageIDs,
}
id, total, err := getSplitIDAndCount(pmapi.AllMailLabel, api, tc.page)
if tc.wantErr == "" {
require.Nil(t, err)
} else {
require.EqualError(t, err, tc.wantErr)
}
assert.Equal(t, tc.wantID, id)
assert.Equal(t, tc.wantTotal, total)
})
}
}
func TestSyncBatch(t *testing.T) {
tests := []struct {
name string
batchIdx int
wantCreatedMessageIDsByBatch [][]string
}{
{
"first-batch",
0,
[][]string{generateIDsR(200, 51), generateIDsR(51, 1)},
},
{
"second-batch",
1,
[][]string{generateIDsR(400, 251), generateIDsR(251, 200)},
},
{
"third-batch",
2,
[][]string{generateIDsR(600, 451), generateIDsR(451, 400)},
},
{
"fourth-batch",
3,
[][]string{generateIDsR(800, 651), generateIDsR(651, 600)},
},
{
"fifth-batch",
4,
[][]string{generateIDsR(1000, 851), generateIDsR(851, 800)},
},
}
for _, tc := range tests {
tc := tc
t.Run(tc.name, func(t *testing.T) {
store := newSyncer()
api := &mockLister{
messageIDs: generateIDs(1, 1000),
}
err := testSyncBatch(t, store, api, tc.batchIdx, "200", "400", "600", "800")
require.Nil(t, err)
require.Equal(t, tc.wantCreatedMessageIDsByBatch, store.createdMessageIDsByBatch)
})
}
}
func TestSyncBatch_FailedListing(t *testing.T) {
store := newSyncer()
api := &mockLister{
err: errors.New("error"),
messageIDs: generateIDs(1, 1000),
}
err := testSyncBatch(t, store, api, 0)
require.EqualError(t, err, "failed to list messages: error")
}
func TestSyncBatch_FailedCreateOrUpdateMessage(t *testing.T) {
store := newSyncer()
store.errCreateOrUpdateMessagesEvent = errors.New("error")
api := &mockLister{
messageIDs: generateIDs(1, 1000),
}
err := testSyncBatch(t, store, api, 0)
require.EqualError(t, err, "failed to create or update messages: error")
}
func testSyncBatch(t *testing.T, store storeSynchronizer, api messageLister, rangeIdx int, splitIDs ...string) error { //nolint:unparam
syncState := newTestSyncState(store, splitIDs...)
idRange := syncState.idRanges[rangeIdx]
shouldStop := 0
return syncBatch(pmapi.AllMailLabel, store, api, syncState, idRange, &shouldStop)
}