Skip to content

Commit

Permalink
Deal with multiple contigs and sequence lengths
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong authored and mergify[bot] committed Feb 25, 2025
1 parent 1aa0233 commit d8dafc7
Show file tree
Hide file tree
Showing 5 changed files with 201 additions and 34 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,9 @@
- If a mismatch ratio is provided to the `infer` command, it only applies during the
`match_samples` phase ({issue}`980`, {pr}`981`, {user}`hyanwong`)

- Get the `sequence_length` of the contig associated with the unmasked sites,
if contig lengths are provided ({pr}`964`, {user}`hyanwong`, {user}`benjeffery`)

**Fixes**

- Properly account for "N" as an unknown ancestral state, and ban "" from being
Expand Down
8 changes: 4 additions & 4 deletions docs/usage.md
Original file line number Diff line number Diff line change
Expand Up @@ -107,10 +107,10 @@ onto branches by {meth}`parsimony<tskit.Tree.map_mutations>`.
It is also possible to *completely* exclude sites and samples, by specifing a boolean
`site_mask` and/or a `sample_mask` when creating the `VariantData` object. Sites or samples with
a mask value of `True` will be completely omitted both from inference and the final tree sequence.
This can be useful, for example, if your VCF file contains multiple chromosomes (in which case
`tsinfer` will need to be run separately on each chromosome) or if you wish to select only a subset
of the chromosome for inference (e.g. to reduce computational load). If a `site_mask` is provided,
note that the ancestral alleles array only specifies alleles for the unmasked sites.
This can be useful, for example, if you wish to select only a subset of the chromosome for
inference, e.g. to reduce computational load. You can also use it to subset inference to a
particular contig, if your dataset contains multiple contigs. Note that if a `site_mask` is provided,
the ancestral states array should only specify alleles for the unmasked sites.

Below, for instance, is an example of including only sites up to position six in the contig
labelled "chr1" in the `example_data.vcz` file:
Expand Down
156 changes: 142 additions & 14 deletions tests/test_variantdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
from tsinfer import formats


def ts_to_dataset(ts, chunks=None, samples=None):
def ts_to_dataset(ts, chunks=None, samples=None, contigs=None):
"""
# From https://github.com/sgkit-dev/sgkit/blob/main/sgkit/tests/test_popgen.py#L63
Convert the specified tskit tree sequence into an sgkit dataset.
Expand All @@ -63,7 +63,7 @@ def ts_to_dataset(ts, chunks=None, samples=None):
genotypes = np.expand_dims(genotypes, axis=2)

ds = sgkit.create_genotype_call_dataset(
variant_contig_names=["1"],
variant_contig_names=["1"] if contigs is None else contigs,
variant_contig=np.zeros(len(tables.sites), dtype=int),
variant_position=tables.sites.position.astype(int),
variant_allele=alleles,
Expand Down Expand Up @@ -145,7 +145,7 @@ def test_variantdata_accessors(tmp_path, in_mem):
assert vd.format_name == "tsinfer-variant-data"
assert vd.format_version == (0, 1)
assert vd.finalised
assert vd.sequence_length == ts.sequence_length + 1337
assert vd.sequence_length == ts.sequence_length
assert vd.num_sites == ts.num_sites
assert vd.sites_metadata_schema == ts.tables.sites.metadata_schema.schema
assert vd.sites_metadata == [site.metadata for site in ts.sites()]
Expand Down Expand Up @@ -218,11 +218,7 @@ def test_variantdata_accessors_defaults(tmp_path, in_mem):
ds = data if in_mem else sgkit.load_dataset(data)

default_schema = tskit.MetadataSchema.permissive_json().schema
with pytest.warns(
UserWarning,
match="`sequence_length` was not found as an attribute in the dataset",
):
assert vdata.sequence_length == ts.sequence_length
assert vdata.sequence_length == ts.sequence_length
assert vdata.sites_metadata_schema == default_schema
assert vdata.sites_metadata == [{} for _ in range(ts.num_sites)]
for time in vdata.sites_time:
Expand Down Expand Up @@ -299,18 +295,116 @@ def test_simulate_genotype_call_dataset(tmp_path):
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
ds = ts_to_dataset(ts)
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
ds.to_zarr(tmp_path, mode="w")
sd = tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
ts = tsinfer.infer(sd)
for v, ds_v, sd_v in zip(ts.variants(), ds.call_genotype, sd.sites_genotypes):
vdata = tsinfer.VariantData(tmp_path, ds["variant_allele"][:, 0].values.astype(str))
ts = tsinfer.infer(vdata)
for v, ds_v, vd_v in zip(ts.variants(), ds.call_genotype, vdata.sites_genotypes):
assert np.all(v.genotypes == ds_v.values.flatten())
assert np.all(v.genotypes == sd_v)
assert np.all(v.genotypes == vd_v)


def test_simulate_genotype_call_dataset_length(tmp_path):
# create_genotype_call_dataset does not save contig lengths
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
ds = ts_to_dataset(ts)
assert "contig_length" not in ds
ds.to_zarr(tmp_path, mode="w")
vdata = tsinfer.VariantData(tmp_path, ds["variant_allele"][:, 0].values.astype(str))
assert vdata.sequence_length == ts.sites_position[-1] + 1

vdata = tsinfer.VariantData(
tmp_path, ds["variant_allele"][:, 0].values.astype(str), sequence_length=1337
)
assert vdata.sequence_length == 1337


class TestMultiContig:
def make_two_ts_dataset(self, path):
# split ts into 2; put them as different contigs in the same dataset
ts = msprime.sim_ancestry(4, sequence_length=1000, random_seed=123)
ts = msprime.sim_mutations(ts, rate=2e-3, random_seed=123)
split_at_site = 7
assert ts.num_sites > 10
site_break = ts.site(split_at_site).position
ts1 = ts.keep_intervals([(0, site_break)]).rtrim()
ts2 = ts.keep_intervals([(site_break, ts.sequence_length)]).ltrim()
ds = ts_to_dataset(ts, contigs=["chr1", "chr2"])
ds.update({"variant_ancestral_allele": ds["variant_allele"][:, 0]})
variant_contig = ds["variant_contig"][:]
variant_contig[split_at_site:] = 1
ds.update({"variant_contig": variant_contig})
variant_position = ds["variant_position"].values
variant_position[split_at_site:] -= int(site_break)
ds.update({"variant_position": ds["variant_position"]})
ds.update(
{"contig_length": np.array([ts1.sequence_length, ts2.sequence_length])}
)
ds.to_zarr(path, mode="w")
return ts1, ts2

def test_unmasked(self, tmp_path):
self.make_two_ts_dataset(tmp_path)
with pytest.raises(ValueError, match=r'multiple contigs \("chr1", "chr2"\)'):
tsinfer.VariantData(tmp_path, "variant_ancestral_allele")

def test_mask(self, tmp_path):
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
vdata = tsinfer.VariantData(
tmp_path,
"variant_ancestral_allele",
site_mask=np.array(ts1.num_sites * [True] + ts2.num_sites * [False]),
)
assert np.all(ts2.sites_position == vdata.sites_position)
assert vdata.contig_id == "chr2"
assert vdata.sequence_length == ts2.sequence_length

@pytest.mark.parametrize("contig_id", ["chr1", "chr2"])
def test_multi_contig(self, contig_id, tmp_path):
tree_seqs = {}
tree_seqs["chr1"], tree_seqs["chr2"] = self.make_two_ts_dataset(tmp_path)
with pytest.raises(ValueError, match="multiple contigs"):
vdata = tsinfer.VariantData(tmp_path, "variant_ancestral_allele")
root = zarr.open(tmp_path)
mask = root["variant_contig"][:] == (1 if contig_id == "chr1" else 0)
vdata = tsinfer.VariantData(
tmp_path, "variant_ancestral_allele", site_mask=mask
)
assert np.all(tree_seqs[contig_id].sites_position == vdata.sites_position)
assert vdata.contig_id == contig_id
assert vdata._contig_index == (0 if contig_id == "chr1" else 1)
assert vdata.sequence_length == tree_seqs[contig_id].sequence_length

def test_mixed_contigs_error(self, tmp_path):
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
mask = np.ones(ts1.num_sites + ts2.num_sites)
# Select two varaints, one from each contig
mask[0] = False
mask[-1] = False
with pytest.raises(ValueError, match="multiple contigs"):
tsinfer.VariantData(
tmp_path,
"variant_ancestral_allele",
site_mask=mask,
)

def test_no_variant_contig(self, tmp_path):
ts1, ts2 = self.make_two_ts_dataset(tmp_path)
root = zarr.open(tmp_path)
del root["variant_contig"]
mask = np.ones(ts1.num_sites + ts2.num_sites)
mask[0] = False
vdata = tsinfer.VariantData(
tmp_path, "variant_ancestral_allele", site_mask=mask
)
assert vdata.sequence_length == ts1.sites_position[0] + 1
assert vdata.contig_id is None
assert vdata._contig_index is None


@pytest.mark.skipif(sys.platform == "win32", reason="File permission errors on Windows")
class TestSgkitMask:
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0], []])
@pytest.mark.parametrize("sites", [[1, 2, 3, 5, 9, 27], [0]])
def test_sgkit_variant_mask(self, tmp_path, sites):
ts, zarr_path = tsutil.make_ts_and_zarr(tmp_path)
ds = sgkit.load_dataset(zarr_path)
Expand Down Expand Up @@ -823,6 +917,20 @@ def test_bad_ancestral_state(self, tmp_path):
with pytest.raises(ValueError, match="cannot contain empty strings"):
tsinfer.VariantData(path, ancestral_state)

def test_ancestral_state_len_not_same_as_mask(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
sgkit.save_dataset(ds, path)
ancestral_state = ds["variant_allele"][:, 0].values.astype(str)
site_mask = np.zeros(ds.sizes["variants"], dtype=bool)
site_mask[0] = True
with pytest.raises(
ValueError,
match="Ancestral state array must be the same length as the number of"
" selected sites",
):
tsinfer.VariantData(path, ancestral_state, site_mask=site_mask)

def test_empty_alleles_not_at_end(self, tmp_path):
path = tmp_path / "data.zarr"
ds = self.simulate_genotype_call_dataset(n_variant=3, n_sample=3, n_ploidy=1)
Expand Down Expand Up @@ -854,3 +962,23 @@ def test_unimplemented_from_tree_sequence(self):
# Requires e.g. https://github.com/tskit-dev/tsinfer/issues/924
with pytest.raises(NotImplementedError):
tsinfer.VariantData.from_tree_sequence(None)

def test_all_masked(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
sgkit.save_dataset(ds, path)
with pytest.raises(ValueError, match="All sites have been masked out"):
tsinfer.VariantData(
path, ds["variant_allele"][:, 0].astype(str), site_mask=np.ones(3, bool)
)

def test_missing_sites_time(self, tmp_path):
path = tmp_path / "data.zarr"
ds = sgkit.simulate_genotype_call_dataset(n_variant=3, n_sample=3, phased=True)
sgkit.save_dataset(ds, path)
with pytest.raises(
ValueError, match="The sites time array XX was not found in the dataset"
):
tsinfer.VariantData(
path, ds["variant_allele"][:, 0].astype(str), sites_time="XX"
)
5 changes: 0 additions & 5 deletions tests/tsutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -339,11 +339,6 @@ def _make_ts_and_zarr(path, add_optional=False, shuffle_alleles=True):
)

if add_optional:
add_attribute_to_dataset(
"sequence_length",
ts.sequence_length + 1337,
path / "data.zarr",
)
sites_md = tables.sites.metadata
sites_md_offset = tables.sites.metadata_offset
add_array_to_dataset(
Expand Down
63 changes: 52 additions & 11 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -2338,6 +2338,11 @@ class VariantData(SampleData):
reasonable approximation to the relative order of ancestors used for inference.
Time values are ignored for sites not used in inference, such as singletons,
sites with more than two alleles, or sites with an unknown ancestral state.
:param int sequence_length: An integer specifying the resulting `sequence_length`
attribute of the output tree sequence. If not specified the `contig_length`
attribute from the undelying zarr store for the contig of the selected variants.
is used. If that is not present then the maximum position plus one of the used
variants is used.
"""

FORMAT_NAME = "tsinfer-variant-data"
Expand All @@ -2351,7 +2356,11 @@ def __init__(
sample_mask=None,
site_mask=None,
sites_time=None,
sequence_length=None,
):
self._sequence_length = sequence_length
self._contig_index = None
self._contig_id = None
try:
if len(path_or_zarr.call_genotype.shape) == 3:
# Assumed to be a VCF Zarr hierarchy
Expand Down Expand Up @@ -2384,9 +2393,16 @@ def __init__(
raise ValueError(
"Site mask array must be the same length as the number of unmasked sites"
)

# We negate the mask as it is much easier in numpy to have True=keep
self.sites_select = ~site_mask.astype(bool)

if np.sum(self.sites_select) == 0:
raise ValueError(
"All sites have been masked out, at least one value"
"must be 'False' in the site mask"
)

if sample_mask is None:
sample_mask = np.full(self._num_individuals_before_mask, False, dtype=bool)
elif isinstance(sample_mask, np.ndarray):
Expand Down Expand Up @@ -2415,6 +2431,22 @@ def __init__(
" zarr dataset, indicating that all the genotypes are"
" unphased"
)

if "variant_contig" in self.data:
used_contigs = self.data.variant_contig[:][self.sites_select]
self._contig_index = used_contigs[0]
self._contig_id = self.data.contig_id[self._contig_index]

if np.any(used_contigs != self._contig_index):
contig_names = ", ".join(
f'"{self.data.contig_id[c]}"' for c in np.unique(used_contigs)
)
raise ValueError(
f"Sites belong to multiple contigs ({contig_names}). "
"Please restrict sites to one contig using the sites_mask argument."
"e.g. `mask=zarr_group['variant_contig'] != wanted_index`"
)

if np.any(np.diff(self.sites_position) <= 0):
raise ValueError(
"Values taken from the variant_position array are not strictly "
Expand All @@ -2436,7 +2468,7 @@ def __init__(
self._sites_time = self.data[sites_time][:][self.sites_select]
except KeyError:
raise ValueError(
f"The sites time {sites_time} was not found" f" in the dataset."
f"The sites time array {sites_time} was not found in the dataset"
)

if isinstance(ancestral_state, np.ndarray):
Expand Down Expand Up @@ -2519,16 +2551,25 @@ def finalised(self):

@functools.cached_property
def sequence_length(self):
try:
return self.data.attrs["sequence_length"]
except KeyError:
warnings.warn(
"`sequence_length` was not found as an attribute in the dataset, so"
" the largest position has been used. It can be set with"
" ds.attrs['sequence_length'] = 1337; ds.to_zarr('path/to/store',"
" mode='a')"
)
return int(np.max(self.data["variant_position"])) + 1
"""
The sequence length of the contig associated with sites used in the dataset.
If set manually then that value is used else if the dataset has recorded
contig lengths use that else the length is calculated from the maximum
variant position plus one.
"""
if self._sequence_length is not None:
return self._sequence_length
if self._contig_index is not None and "contig_length" in self.data:
return self.data.contig_length[self._contig_index]
return int(np.max(self.sites_position)) + 1

@property
def contig_id(self):
"""
The contig ID (name) for all used sites, or None if no
contig IDs were present in the zarr dataset
"""
return self._contig_id

@property
def num_sites(self):
Expand Down

0 comments on commit d8dafc7

Please sign in to comment.