Skip to content

Commit

Permalink
convert to pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Aug 13, 2024
1 parent 255d39b commit ae890c4
Showing 1 changed file with 52 additions and 43 deletions.
95 changes: 52 additions & 43 deletions tests/unit/parallel_ops/test_parallel_map.py
Original file line number Diff line number Diff line change
@@ -1,66 +1,75 @@
import unittest
from functools import cached_property
from unittest.mock import MagicMock

import pytest

from depiction.parallel_ops import ParallelConfig
from depiction.parallel_ops.parallel_map import ParallelMap


class TestParallelMap(unittest.TestCase):
def setUp(self) -> None:
self.mock_config = MagicMock(name="mock_config", spec=ParallelConfig)
@pytest.fixture
def mock_config() -> ParallelConfig:
return MagicMock(name="mock_config", spec=ParallelConfig, n_jobs=2)


@pytest.fixture
def mock_parallel(mock_config) -> ParallelMap:
return ParallelMap(config=mock_config)


def test_config(mock_parallel, mock_config):
assert mock_parallel.config == mock_config


def test_reduce_concat(mock_parallel):
results = [[1, 2, 3], [4], [5, 6]]
reduced = mock_parallel.reduce_concat(results)
assert reduced == [1, 2, 3, 4, 5, 6]

@cached_property
def mock_parallel(self) -> ParallelMap:
return ParallelMap(config=self.mock_config)

def test_config(self) -> None:
self.assertEqual(self.mock_config, self.mock_parallel.config)
def test_call_when_default():
def mock_operation(x):
return x * 2

def test_reduce_concat(self) -> None:
results = [[1, 2, 3], [4], [5, 6]]
reduced = self.mock_parallel.reduce_concat(results)
self.assertListEqual([1, 2, 3, 4, 5, 6], reduced)
tasks = [1, 2, 3, 4, 5]
mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None)
mock_parallel = ParallelMap(config=mock_config)
result = mock_parallel(operation=mock_operation, tasks=tasks)

def test_call_when_default(self) -> None:
def mock_operation(x):
return x * 2
assert result == [2, 4, 6, 8, 10]

tasks = [1, 2, 3, 4, 5]
self.mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None)
result = self.mock_parallel(operation=mock_operation, tasks=tasks)

self.assertListEqual([2, 4, 6, 8, 10], result)
def test_call_when_bind_kwargs():
def mock_operation(x, y):
return x * y

def test_call_when_bind_kwargs(self) -> None:
def mock_operation(x, y):
return x * y
tasks = [1, 2, 3, 4, 5]
mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None)
mock_parallel = ParallelMap(config=mock_config)
result = mock_parallel(operation=mock_operation, tasks=tasks, bind_kwargs={"y": 3})

tasks = [1, 2, 3, 4, 5]
self.mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None)
result = self.mock_parallel(operation=mock_operation, tasks=tasks, bind_kwargs={"y": 3})
assert result == [3, 6, 9, 12, 15]

self.assertListEqual([3, 6, 9, 12, 15], result)

def test_call_when_reduce_fn(self) -> None:
def mock_operation(x_list):
return [x * 2 for x in x_list]
@pytest.mark.parametrize(
"reduce_fn, expected_result", [(None, [[2, 4], [6, 8], [10]]), (ParallelMap.reduce_concat, [2, 4, 6, 8, 10])]
)
def test_call_when_reduce_fn(reduce_fn, expected_result):
def mock_operation(x_list):
return [x * 2 for x in x_list]

tasks = [[1, 2], [3, 4], [5]]
self.mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None)
tasks = [[1, 2], [3, 4], [5]]
mock_config = ParallelConfig(n_jobs=3, verbose=0, task_size=None)
mock_parallel = ParallelMap(config=mock_config)

with self.subTest(reduce_fn="list"):
result_list = self.mock_parallel(operation=mock_operation, tasks=tasks)
self.assertListEqual([[2, 4], [6, 8], [10]], result_list)
result = mock_parallel(operation=mock_operation, tasks=tasks, reduce_fn=reduce_fn)
assert result == expected_result

with self.subTest(reduce_fn="concat"):
result_list = self.mock_parallel(operation=mock_operation, tasks=tasks, reduce_fn=ParallelMap.reduce_concat)
self.assertListEqual([2, 4, 6, 8, 10], result_list)

def test_repr(self) -> None:
self.mock_config = "1234"
self.assertEqual("ParallelMap(config='1234')", repr(self.mock_parallel))
def test_repr():
mock_config = "1234"
mock_parallel = ParallelMap(config=mock_config)
assert repr(mock_parallel) == "ParallelMap(config='1234')"


if __name__ == "__main__":
unittest.main()
pytest.main()

0 comments on commit ae890c4

Please sign in to comment.