Skip to content

Commit

Permalink
add sample_points method
Browse files Browse the repository at this point in the history
  • Loading branch information
leoschwarz committed Jul 22, 2024
1 parent 2e21071 commit cabf5dd
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
9 changes: 9 additions & 0 deletions src/depiction/clustering/subsampled.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,3 +84,12 @@ def assign_num_per_cell(self, n_total: int, assignment: dict[int, NDArray[int]])
n_per_cell[np.argmax(n_per_cell)] -= 1

return n_per_cell

def sample_points(self, array: DataArray, n_samples: int, rng: np.random.Generator) -> DataArray:
"""Sample points from the given array such that the points are stratified across the grid cells."""
assignment = self.assign_points(array)
n_per_cell = self.assign_num_per_cell(n_total=n_samples, assignment=assignment)
sampled_indices = np.concatenate(
[rng.choice(assignment[i], n_per_cell[i], replace=False) for i in range(len(n_per_cell))]
)
return array.isel(i=sampled_indices)
48 changes: 48 additions & 0 deletions tests/unit/clustering/test_susbsampled.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pandas as pd
import pytest
from depiction.clustering.subsampled import StratifiedGrid
from numpy.random import default_rng
from xarray import DataArray


Expand Down Expand Up @@ -165,5 +166,52 @@ def test_assign_num_per_cell_some_empty_cells(stratified_grid):
np.testing.assert_array_equal(np.array([1, 0, 1, 0, 1, 2, 0, 2]), result)


def test_sample_points(stratified_grid):
# TODO improve this test (was generated)
coords = np.array(
[
[0, 0],
[0.2, 0.1],
[0.4, 0.2],
[0.6, 0.3],
[0.8, 0.4],
[1, 0.5],
[0.1, 0.6],
[0.3, 0.7],
[0.5, 0.8],
[0.7, 0.9],
[0.9, 1],
]
)

index = pd.MultiIndex.from_arrays([coords[:, 0], coords[:, 1]], names=["x", "y"])

array = DataArray(np.arange(11), dims=("i"), coords={"i": index})

n_samples = 5
rng = default_rng(seed=42)

result = stratified_grid.sample_points(array, n_samples, rng)

# Check that the number of sampled points is correct
assert len(result) == n_samples

# Check that the sampled points are a subset of the original points
assert set(result.i.values).issubset(set(array.i.values))

# Check that the points are stratified
# This is a bit tricky to test definitively, but we can check that
# points are distributed across different cells
x_coords = result.x.values
y_coords = result.y.values

assert len(np.unique(np.digitize(x_coords, stratified_grid.edges_x[1:-1]))) > 1
assert len(np.unique(np.digitize(y_coords, stratified_grid.edges_y[1:-1]))) > 1

# Run the sampling multiple times to check for randomness (TODO reconsider)
results = [stratified_grid.sample_points(array, n_samples, rng) for _ in range(10)]
assert not all(np.array_equal(results[0], result) for result in results[1:])


if __name__ == "__main__":
pytest.main()

0 comments on commit cabf5dd

Please sign in to comment.