diff --git a/src/depiction/clustering/subsampled.py b/src/depiction/clustering/subsampled.py
index 9236cae..203e03b 100644
--- a/src/depiction/clustering/subsampled.py
+++ b/src/depiction/clustering/subsampled.py
@@ -84,3 +84,19 @@ def assign_num_per_cell(self, n_total: int, assignment: dict[int, NDArray[int]])
             n_per_cell[np.argmax(n_per_cell)] -= 1
 
         return n_per_cell
+
+    def sample_points(self, array: DataArray, n_samples: int, rng: np.random.Generator) -> DataArray:
+        """Sample points from the given array such that the points are stratified across the grid cells.
+
+        :param array: the array to sample from; it is indexed positionally along its ``i`` dimension
+        :param n_samples: the total number of points requested, passed to ``assign_num_per_cell`` as ``n_total``
+        :param rng: the random generator used to draw the points within each cell
+        :return: the selection of ``array`` at the sampled positions along ``i``
+        """
+        assignment = self.assign_points(array)
+        n_per_cell = self.assign_num_per_cell(n_total=n_samples, assignment=assignment)
+        # Draw without replacement inside each cell, then merge the per-cell draws in cell order.
+        sampled_indices = np.concatenate(
+            [rng.choice(assignment[cell], n_cell, replace=False) for cell, n_cell in enumerate(n_per_cell)]
+        )
+        return array.isel(i=sampled_indices)
diff --git a/tests/unit/clustering/test_susbsampled.py b/tests/unit/clustering/test_susbsampled.py
index 604d12a..0d55c7d 100644
--- a/tests/unit/clustering/test_susbsampled.py
+++ b/tests/unit/clustering/test_susbsampled.py
@@ -2,6 +2,7 @@
 import pandas as pd
 import pytest
 from depiction.clustering.subsampled import StratifiedGrid
+from numpy.random import default_rng
 from xarray import DataArray
 
 
@@ -165,5 +166,53 @@ def test_assign_num_per_cell_some_empty_cells(stratified_grid):
     np.testing.assert_array_equal(np.array([1, 0, 1, 0, 1, 2, 0, 2]), result)
 
 
+def test_sample_points(stratified_grid):
+    """Sampling returns the requested number of existing points, spread over more than one cell."""
+    # TODO improve this test (was generated)
+    coords = np.array(
+        [
+            [0, 0],
+            [0.2, 0.1],
+            [0.4, 0.2],
+            [0.6, 0.3],
+            [0.8, 0.4],
+            [1, 0.5],
+            [0.1, 0.6],
+            [0.3, 0.7],
+            [0.5, 0.8],
+            [0.7, 0.9],
+            [0.9, 1],
+        ]
+    )
+
+    index = pd.MultiIndex.from_arrays([coords[:, 0], coords[:, 1]], names=["x", "y"])
+
+    array = DataArray(np.arange(11), dims=("i",), coords={"i": index})
+
+    n_samples = 5
+    rng = default_rng(seed=42)
+
+    result = stratified_grid.sample_points(array, n_samples, rng)
+
+    # Check that the number of sampled points is correct
+    assert len(result) == n_samples
+
+    # Check that the sampled points are a subset of the original points
+    assert set(result.i.values).issubset(set(array.i.values))
+
+    # Check that the points are stratified
+    # This is a bit tricky to test definitively, but we can check that
+    # points are distributed across different cells
+    x_coords = result.x.values
+    y_coords = result.y.values
+
+    assert len(np.unique(np.digitize(x_coords, stratified_grid.edges_x[1:-1]))) > 1
+    assert len(np.unique(np.digitize(y_coords, stratified_grid.edges_y[1:-1]))) > 1
+
+    # Run the sampling multiple times to check for randomness (TODO reconsider)
+    results = [stratified_grid.sample_points(array, n_samples, rng) for _ in range(10)]
+    assert not all(np.array_equal(results[0], other) for other in results[1:])
+
+
 if __name__ == "__main__":
     pytest.main()