// Copyright (c) 2021 Proton Technologies AG // // This file is part of ProtonMail Bridge. // // ProtonMail Bridge is free software: you can redistribute it and/or modify // it under the terms of the GNU General Public License as published by // the Free Software Foundation, either version 3 of the License, or // (at your option) any later version. // // ProtonMail Bridge is distributed in the hope that it will be useful, // but WITHOUT ANY WARRANTY; without even the implied warranty of // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the // GNU General Public License for more details. // // You should have received a copy of the GNU General Public License // along with ProtonMail Bridge. If not, see . package parallel import ( "errors" "fmt" "math" "testing" "time" r "github.com/stretchr/testify/require" ) // nolint[gochecknoglobals] var ( testInput = []interface{}{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} wantOutput = []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} testProcessSleep = 100 // ms runParallelTimeOverhead = 150 // ms ) func TestParallel(t *testing.T) { workersTests := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} for _, workers := range workersTests { workers := workers t.Run(fmt.Sprintf("%d", workers), func(t *testing.T) { collected := make([]int, 0) collect := func(idx int, value interface{}) error { collected = append(collected, value.(int)) return nil } tstart := time.Now() err := RunParallel(workers, testInput, processSleep, collect) duration := time.Since(tstart) r.Nil(t, err) r.Equal(t, wantOutput, collected) // Check the order is always kept. wantMinDuration := int(math.Ceil(float64(len(testInput))/float64(workers))) * testProcessSleep wantMaxDuration := wantMinDuration + runParallelTimeOverhead r.True(t, duration.Nanoseconds() > int64(wantMinDuration*1000000), "Duration too short: %v (expected: %v)", duration, wantMinDuration) r.True(t, duration.Nanoseconds() < int64(wantMaxDuration*1000000), "Duration too long: %v (expected: %v)", duration, wantMaxDuration) }) } } func TestParallelEmptyInput(t *testing.T) { workersTests := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} for _, workers := range workersTests { workers := workers t.Run(fmt.Sprintf("%d", workers), func(t *testing.T) { err := RunParallel(workers, []interface{}{}, processSleep, collectNil) r.Nil(t, err) }) } } func TestParallelErrorInProcess(t *testing.T) { workersTests := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} for _, workers := range workersTests { workers := workers t.Run(fmt.Sprintf("%d", workers), func(t *testing.T) { var lastCollected int process := func(value interface{}) (interface{}, error) { time.Sleep(10 * time.Millisecond) if value.(int) == 5 { return nil, errors.New("Error") } return value, nil } collect := func(idx int, value interface{}) error { lastCollected = value.(int) //nolint[forcetypeassert] return nil } err := RunParallel(workers, testInput, process, collect) r.EqualError(t, err, "Error") time.Sleep(10 * time.Millisecond) r.True(t, lastCollected < 5, "Last collected cannot be higher that 5, got: %d", lastCollected) }) } } func TestParallelErrorInCollect(t *testing.T) { workersTests := []int{1, 2, 3, 4, 5, 6, 7, 8, 9, 10} for _, workers := range workersTests { workers := workers t.Run(fmt.Sprintf("%d", workers), func(t *testing.T) { collect := func(idx int, value interface{}) error { if value.(int) == 5 { return errors.New("Error") } return nil } err := RunParallel(workers, testInput, processSleep, collect) r.EqualError(t, err, "Error") }) } } func processSleep(value interface{}) (interface{}, error) { time.Sleep(time.Duration(testProcessSleep) * time.Millisecond) return value.(int), nil } func collectNil(idx int, value interface{}) error { return nil }