-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
255d39b
commit ae890c4
Showing
1 changed file
with
52 additions
and
43 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,66 +1,75 @@ | ||
import unittest | ||
from functools import cached_property | ||
from unittest.mock import MagicMock | ||
|
||
import pytest | ||
|
||
from depiction.parallel_ops import ParallelConfig | ||
from depiction.parallel_ops.parallel_map import ParallelMap | ||
|
||
|
||
class TestParallelMap(unittest.TestCase): | ||
def setUp(self) -> None: | ||
self.mock_config = MagicMock(name="mock_config", spec=ParallelConfig) | ||
@pytest.fixture | ||
def mock_config() -> ParallelConfig: | ||
return MagicMock(name="mock_config", spec=ParallelConfig, n_jobs=2) | ||
|
||
|
||
@pytest.fixture | ||
def mock_parallel(mock_config) -> ParallelMap: | ||
return ParallelMap(config=mock_config) | ||
|
||
|
||
def test_config(mock_parallel, mock_config): | ||
assert mock_parallel.config == mock_config | ||
|
||
|
||
def test_reduce_concat(mock_parallel): | ||
results = [[1, 2, 3], [4], [5, 6]] | ||
reduced = mock_parallel.reduce_concat(results) | ||
assert reduced == [1, 2, 3, 4, 5, 6] | ||
|
||
@cached_property | ||
def mock_parallel(self) -> ParallelMap: | ||
return ParallelMap(config=self.mock_config) | ||
|
||
def test_config(self) -> None: | ||
self.assertEqual(self.mock_config, self.mock_parallel.config) | ||
def test_call_when_default(): | ||
def mock_operation(x): | ||
return x * 2 | ||
|
||
def test_reduce_concat(self) -> None: | ||
results = [[1, 2, 3], [4], [5, 6]] | ||
reduced = self.mock_parallel.reduce_concat(results) | ||
self.assertListEqual([1, 2, 3, 4, 5, 6], reduced) | ||
tasks = [1, 2, 3, 4, 5] | ||
mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) | ||
mock_parallel = ParallelMap(config=mock_config) | ||
result = mock_parallel(operation=mock_operation, tasks=tasks) | ||
|
||
def test_call_when_default(self) -> None: | ||
def mock_operation(x): | ||
return x * 2 | ||
assert result == [2, 4, 6, 8, 10] | ||
|
||
tasks = [1, 2, 3, 4, 5] | ||
self.mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) | ||
result = self.mock_parallel(operation=mock_operation, tasks=tasks) | ||
|
||
self.assertListEqual([2, 4, 6, 8, 10], result) | ||
def test_call_when_bind_kwargs(): | ||
def mock_operation(x, y): | ||
return x * y | ||
|
||
def test_call_when_bind_kwargs(self) -> None: | ||
def mock_operation(x, y): | ||
return x * y | ||
tasks = [1, 2, 3, 4, 5] | ||
mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) | ||
mock_parallel = ParallelMap(config=mock_config) | ||
result = mock_parallel(operation=mock_operation, tasks=tasks, bind_kwargs={"y": 3}) | ||
|
||
tasks = [1, 2, 3, 4, 5] | ||
self.mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) | ||
result = self.mock_parallel(operation=mock_operation, tasks=tasks, bind_kwargs={"y": 3}) | ||
assert result == [3, 6, 9, 12, 15] | ||
|
||
self.assertListEqual([3, 6, 9, 12, 15], result) | ||
|
||
def test_call_when_reduce_fn(self) -> None: | ||
def mock_operation(x_list): | ||
return [x * 2 for x in x_list] | ||
@pytest.mark.parametrize( | ||
"reduce_fn, expected_result", [(None, [[2, 4], [6, 8], [10]]), (ParallelMap.reduce_concat, [2, 4, 6, 8, 10])] | ||
) | ||
def test_call_when_reduce_fn(reduce_fn, expected_result): | ||
def mock_operation(x_list): | ||
return [x * 2 for x in x_list] | ||
|
||
tasks = [[1, 2], [3, 4], [5]] | ||
self.mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) | ||
tasks = [[1, 2], [3, 4], [5]] | ||
mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None) | ||
mock_parallel = ParallelMap(config=mock_config) | ||
|
||
with self.subTest(reduce_fn="list"): | ||
result_list = self.mock_parallel(operation=mock_operation, tasks=tasks) | ||
self.assertListEqual([[2, 4], [6, 8], [10]], result_list) | ||
result = mock_parallel(operation=mock_operation, tasks=tasks, reduce_fn=reduce_fn) | ||
assert result == expected_result | ||
|
||
with self.subTest(reduce_fn="concat"): | ||
result_list = self.mock_parallel(operation=mock_operation, tasks=tasks, reduce_fn=ParallelMap.reduce_concat) | ||
self.assertListEqual([2, 4, 6, 8, 10], result_list) | ||
|
||
def test_repr(self) -> None: | ||
self.mock_config = "1234" | ||
self.assertEqual("ParallelMap(config='1234')", repr(self.mock_parallel)) | ||
def test_repr(): | ||
mock_config = "1234" | ||
mock_parallel = ParallelMap(config=mock_config) | ||
assert repr(mock_parallel) == "ParallelMap(config='1234')" | ||
|
||
|
||
if __name__ == "__main__": | ||
unittest.main() | ||
pytest.main() |