diff --git a/tests/unit/parallel_ops/test_parallel_map.py b/tests/unit/parallel_ops/test_parallel_map.py index cbcf38c..cbdb8a5 100644 --- a/tests/unit/parallel_ops/test_parallel_map.py +++ b/tests/unit/parallel_ops/test_parallel_map.py @@ -1,66 +1,75 @@ -import unittest -from functools import cached_property from unittest.mock import MagicMock +import pytest + from depiction.parallel_ops import ParallelConfig from depiction.parallel_ops.parallel_map import ParallelMap -class TestParallelMap(unittest.TestCase): - def setUp(self) -> None: - self.mock_config = MagicMock(name="mock_config", spec=ParallelConfig) +@pytest.fixture +def mock_config() -> ParallelConfig: + return MagicMock(name="mock_config", spec=ParallelConfig, n_jobs=2) + + +@pytest.fixture +def mock_parallel(mock_config) -> ParallelMap: + return ParallelMap(config=mock_config) + + +def test_config(mock_parallel, mock_config): + assert mock_parallel.config == mock_config + + +def test_reduce_concat(mock_parallel): + results = [[1, 2, 3], [4], [5, 6]] + reduced = mock_parallel.reduce_concat(results) + assert reduced == [1, 2, 3, 4, 5, 6] - @cached_property - def mock_parallel(self) -> ParallelMap: - return ParallelMap(config=self.mock_config) - def test_config(self) -> None: - self.assertEqual(self.mock_config, self.mock_parallel.config) +def test_call_when_default(): + def mock_operation(x): + return x * 2 - def test_reduce_concat(self) -> None: - results = [[1, 2, 3], [4], [5, 6]] - reduced = self.mock_parallel.reduce_concat(results) - self.assertListEqual([1, 2, 3, 4, 5, 6], reduced) + tasks = [1, 2, 3, 4, 5] + mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) + mock_parallel = ParallelMap(config=mock_config) + result = mock_parallel(operation=mock_operation, tasks=tasks) - def test_call_when_default(self) -> None: - def mock_operation(x): - return x * 2 + assert result == [2, 4, 6, 8, 10] - tasks = [1, 2, 3, 4, 5] - self.mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) - result = self.mock_parallel(operation=mock_operation, tasks=tasks) - self.assertListEqual([2, 4, 6, 8, 10], result) +def test_call_when_bind_kwargs(): + def mock_operation(x, y): + return x * y - def test_call_when_bind_kwargs(self) -> None: - def mock_operation(x, y): - return x * y + tasks = [1, 2, 3, 4, 5] + mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) + mock_parallel = ParallelMap(config=mock_config) + result = mock_parallel(operation=mock_operation, tasks=tasks, bind_kwargs={"y": 3}) - tasks = [1, 2, 3, 4, 5] - self.mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) - result = self.mock_parallel(operation=mock_operation, tasks=tasks, bind_kwargs={"y": 3}) + assert result == [3, 6, 9, 12, 15] - self.assertListEqual([3, 6, 9, 12, 15], result) - def test_call_when_reduce_fn(self) -> None: - def mock_operation(x_list): - return [x * 2 for x in x_list] +@pytest.mark.parametrize( + "reduce_fn, expected_result", [(None, [[2, 4], [6, 8], [10]]), (ParallelMap.reduce_concat, [2, 4, 6, 8, 10])] +) +def test_call_when_reduce_fn(reduce_fn, expected_result): + def mock_operation(x_list): + return [x * 2 for x in x_list] - tasks = [[1, 2], [3, 4], [5]] - self.mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) + tasks = [[1, 2], [3, 4], [5]] + mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) + mock_parallel = ParallelMap(config=mock_config) - with self.subTest(reduce_fn="list"): - result_list = self.mock_parallel(operation=mock_operation, tasks=tasks) - self.assertListEqual([[2, 4], [6, 8], [10]], result_list) + result = mock_parallel(operation=mock_operation, tasks=tasks, reduce_fn=reduce_fn) + assert result == expected_result - with self.subTest(reduce_fn="concat"): - result_list = self.mock_parallel(operation=mock_operation, tasks=tasks, reduce_fn=ParallelMap.reduce_concat) - self.assertListEqual([2, 4, 6, 8, 10], result_list) - def test_repr(self) -> None: - self.mock_config = "1234" - self.assertEqual("ParallelMap(config='1234')", repr(self.mock_parallel)) +def test_repr(): + mock_config = "1234" + mock_parallel = ParallelMap(config=mock_config) + assert repr(mock_parallel) == "ParallelMap(config='1234')" if __name__ == "__main__": - unittest.main() + pytest.main()