diff --git a/tests/test_variantdata.py b/tests/test_variantdata.py index e4571222..6a9f922c 100644 --- a/tests/test_variantdata.py +++ b/tests/test_variantdata.py @@ -133,146 +133,141 @@ def test_sgkit_individual_metadata_not_clobbered(tmp_path): @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") -def test_sgkit_dataset_accessors(tmp_path): - ts, zarr_path = tsutil.make_ts_and_zarr( - tmp_path, add_optional=True, shuffle_alleles=False - ) - samples = tsinfer.VariantData( - zarr_path, "variant_ancestral_allele", sites_time="sites_time" - ) - ds = sgkit.load_dataset(zarr_path) - - assert samples.format_name == "tsinfer-variant-data" - assert samples.format_version == (0, 1) - assert samples.finalised - assert samples.sequence_length == ts.sequence_length + 1337 - assert samples.num_sites == ts.num_sites - assert samples.sites_metadata_schema == ts.tables.sites.metadata_schema.schema - assert samples.sites_metadata == [site.metadata for site in ts.sites()] - assert np.array_equal(samples.sites_time, np.arange(ts.num_sites) / ts.num_sites) - assert np.array_equal(samples.sites_position, ts.tables.sites.position) - for alleles, v in zip(samples.sites_alleles, ts.variants()): +@pytest.mark.parametrize("in_mem", [True, False]) +def test_variantdata_accessors(tmp_path, in_mem): + path = None if in_mem else tmp_path + ts, data = tsutil.make_ts_and_zarr(path, add_optional=True, shuffle_alleles=False) + vd = tsinfer.VariantData(data, "variant_ancestral_allele", sites_time="sites_time") + ds = data if in_mem else sgkit.load_dataset(data) + + assert vd.format_name == "tsinfer-variant-data" + assert vd.format_version == (0, 1) + assert vd.finalised + assert vd.sequence_length == ts.sequence_length + 1337 + assert vd.num_sites == ts.num_sites + assert vd.sites_metadata_schema == ts.tables.sites.metadata_schema.schema + assert vd.sites_metadata == [site.metadata for site in ts.sites()] + assert np.array_equal(vd.sites_time, np.arange(ts.num_sites) / ts.num_sites) + assert np.array_equal(vd.sites_position, ts.tables.sites.position) + for alleles, v in zip(vd.sites_alleles, ts.variants()): # sgkit alleles are padded to be rectangular assert np.all(alleles[: len(v.alleles)] == v.alleles) assert np.all(alleles[len(v.alleles) :] == "") - assert np.array_equal(samples.sites_select, np.ones(ts.num_sites, dtype=bool)) + assert np.array_equal(vd.sites_select, np.ones(ts.num_sites, dtype=bool)) assert np.array_equal( - samples.sites_ancestral_allele, np.zeros(ts.num_sites, dtype=np.int8) + vd.sites_ancestral_allele, np.zeros(ts.num_sites, dtype=np.int8) ) - assert np.array_equal(samples.sites_genotypes, ts.genotype_matrix()) + assert np.array_equal(vd.sites_genotypes, ts.genotype_matrix()) assert np.array_equal( - samples.provenances_timestamp, ["2021-01-01T00:00:00", "2021-01-02T00:00:00"] + vd.provenances_timestamp, ["2021-01-01T00:00:00", "2021-01-02T00:00:00"] ) - assert samples.provenances_record == [{"foo": 1}, {"foo": 2}] - assert samples.num_samples == ts.num_samples + assert vd.provenances_record == [{"foo": 1}, {"foo": 2}] + assert vd.num_samples == ts.num_samples assert np.array_equal( - samples.samples_individual, np.repeat(np.arange(ts.num_samples // 3), 3) + vd.samples_individual, np.repeat(np.arange(ts.num_samples // 3), 3) ) - assert samples.metadata_schema == tsutil.example_schema("example").schema - assert samples.metadata == ts.tables.metadata + assert vd.metadata_schema == tsutil.example_schema("example").schema + assert vd.metadata == ts.tables.metadata assert ( - samples.populations_metadata_schema - == ts.tables.populations.metadata_schema.schema + vd.populations_metadata_schema == ts.tables.populations.metadata_schema.schema ) - assert samples.populations_metadata == [pop.metadata for pop in ts.populations()] - assert samples.num_individuals == ts.num_individuals + assert vd.populations_metadata == [pop.metadata for pop in ts.populations()] + assert vd.num_individuals == ts.num_individuals assert np.array_equal( - samples.individuals_time, np.arange(ts.num_individuals, dtype=np.float32) + vd.individuals_time, np.arange(ts.num_individuals, dtype=np.float32) ) assert ( - samples.individuals_metadata_schema - == ts.tables.individuals.metadata_schema.schema + vd.individuals_metadata_schema == ts.tables.individuals.metadata_schema.schema ) - assert samples.individuals_metadata == [ + assert vd.individuals_metadata == [ {"variant_data_sample_id": sample_id, **ind.metadata} - for ind, sample_id in zip(ts.individuals(), ds["sample_id"].values) + for ind, sample_id in zip(ts.individuals(), ds.sample_id[:]) ] assert np.array_equal( - samples.individuals_location, + vd.individuals_location, np.tile(np.array([["0", "1"]], dtype="float32"), (ts.num_individuals, 1)), ) assert np.array_equal( - samples.individuals_population, np.zeros(ts.num_individuals, dtype="int32") + vd.individuals_population, np.zeros(ts.num_individuals, dtype="int32") ) assert np.array_equal( - samples.individuals_flags, + vd.individuals_flags, np.random.RandomState(42).randint( 0, 2_000_000, ts.num_individuals, dtype="int32" ), ) # Need to shuffle for the ancestral allele test - ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path, add_optional=True) - samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") + ts, data = tsutil.make_ts_and_zarr(path, add_optional=True) + vd = tsinfer.VariantData(data, "variant_ancestral_allele") for i in range(ts.num_sites): assert ( - samples.sites_alleles[i][samples.sites_ancestral_allele[i]] + vd.sites_alleles[i][vd.sites_ancestral_allele[i]] == ts.site(i).ancestral_state ) @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") -def test_sgkit_accessors_defaults(tmp_path): - ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") - ds = sgkit.load_dataset(zarr_path) +@pytest.mark.parametrize("in_mem", [True, False]) +def test_variantdata_accessors_defaults(tmp_path, in_mem): + path = None if in_mem else tmp_path + ts, data = tsutil.make_ts_and_zarr(path) + vdata = tsinfer.VariantData(data, "variant_ancestral_allele") + ds = data if in_mem else sgkit.load_dataset(data) default_schema = tskit.MetadataSchema.permissive_json().schema - assert samples.sequence_length == ts.sequence_length - assert samples.sites_metadata_schema == default_schema - assert samples.sites_metadata == [{} for _ in range(ts.num_sites)] - for time in samples.sites_time: + assert vdata.sequence_length == ts.sequence_length + assert vdata.sites_metadata_schema == default_schema + assert vdata.sites_metadata == [{} for _ in range(ts.num_sites)] + for time in vdata.sites_time: assert tskit.is_unknown_time(time) - assert np.array_equal(samples.sites_select, np.ones(ts.num_sites, dtype=bool)) - assert np.array_equal(samples.provenances_timestamp, []) - assert np.array_equal(samples.provenances_record, []) - assert samples.metadata_schema == default_schema - assert samples.metadata == {} - assert samples.populations_metadata_schema == default_schema - assert samples.populations_metadata == [] - assert samples.individuals_metadata_schema == default_schema - assert samples.individuals_metadata == [ - {"variant_data_sample_id": sample_id} for sample_id in ds["sample_id"].values + assert np.array_equal(vdata.sites_select, np.ones(ts.num_sites, dtype=bool)) + assert np.array_equal(vdata.provenances_timestamp, []) + assert np.array_equal(vdata.provenances_record, []) + assert vdata.metadata_schema == default_schema + assert vdata.metadata == {} + assert vdata.populations_metadata_schema == default_schema + assert vdata.populations_metadata == [] + assert vdata.individuals_metadata_schema == default_schema + assert vdata.individuals_metadata == [ + {"variant_data_sample_id": sample_id} for sample_id in ds.sample_id[:] ] - for time in samples.individuals_time: + for time in vdata.individuals_time: assert tskit.is_unknown_time(time) assert np.array_equal( - samples.individuals_location, np.array([[]] * ts.num_individuals, dtype=float) + vdata.individuals_location, np.array([[]] * ts.num_individuals, dtype=float) ) assert np.array_equal( - samples.individuals_population, np.full(ts.num_individuals, tskit.NULL) + vdata.individuals_population, np.full(ts.num_individuals, tskit.NULL) ) assert np.array_equal( - samples.individuals_flags, np.zeros(ts.num_individuals, dtype=int) + vdata.individuals_flags, np.zeros(ts.num_individuals, dtype=int) ) @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") -def test_variantdata_sites_time_default(tmp_path): - ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) - samples = tsinfer.VariantData(zarr_path, "variant_ancestral_allele") +def test_variantdata_sites_time_default(): + ts, data = tsutil.make_ts_and_zarr() + vdata = tsinfer.VariantData(data, "variant_ancestral_allele") assert ( - np.all(np.isnan(samples.sites_time)) - and samples.sites_time.size == samples.num_sites + np.all(np.isnan(vdata.sites_time)) and vdata.sites_time.size == vdata.num_sites ) @pytest.mark.skipif(sys.platform == "win32", reason="No cyvcf2 on windows") -def test_variantdata_sites_time_array(tmp_path): - ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path) +def test_variantdata_sites_time_array(): + ts, data = tsutil.make_ts_and_zarr() sites_time = np.arange(ts.num_sites) - samples = tsinfer.VariantData( - zarr_path, "variant_ancestral_allele", sites_time=sites_time - ) - assert np.array_equal(samples.sites_time, sites_time) + vdata = tsinfer.VariantData(data, "variant_ancestral_allele", sites_time=sites_time) + assert np.array_equal(vdata.sites_time, sites_time) wrong_length_sites_time = np.arange(ts.num_sites + 1) with pytest.raises( ValueError, match="Sites time array must be the same length as the number of selected sites", ): tsinfer.VariantData( - zarr_path, + data, "variant_ancestral_allele", sites_time=wrong_length_sites_time, ) @@ -302,17 +297,17 @@ def test_sgkit_variant_mask(self, tmp_path, sites): for i in sites: sites_mask[i] = False tsutil.add_array_to_dataset("variant_mask_42", sites_mask, zarr_path) - samples = tsinfer.VariantData( + vdata = tsinfer.VariantData( zarr_path, "variant_ancestral_allele", site_mask="variant_mask_42", ) - assert samples.num_sites == len(sites) - assert np.array_equal(samples.sites_select, ~sites_mask) + assert vdata.num_sites == len(sites) + assert np.array_equal(vdata.sites_select, ~sites_mask) assert np.array_equal( - samples.sites_position, ts.tables.sites.position[~sites_mask] + vdata.sites_position, ts.tables.sites.position[~sites_mask] ) - inf_ts = tsinfer.infer(samples) + inf_ts = tsinfer.infer(vdata) assert np.array_equal( ts.genotype_matrix()[~sites_mask], inf_ts.genotype_matrix() ) @@ -675,6 +670,14 @@ def test_sgkit_ancestor(small_sd_fixture, tmp_path): class TestVariantDataErrors: + def test_bad_zarr_spec(self): + ds = zarr.group() + ds["call_genotype"] = zarr.array(np.zeros(10, dtype=np.int8)) + with pytest.raises( + ValueError, match="Expecting a VCF Zarr object with 3D call_genotype array" + ): + tsinfer.VariantData(ds, np.zeros(10, dtype="