Skip to content

Commit

Permalink
Use struct for node metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
hyanwong committed Mar 5, 2022
1 parent 99cad13 commit 3786aea
Show file tree
Hide file tree
Showing 3 changed files with 44 additions and 17 deletions.
14 changes: 7 additions & 7 deletions tests/test_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1094,9 +1094,11 @@ def test_from_standard_tree_sequence(self):
assert i1.flags == i2.flags
assert tsutil.json_metadata_is_subset(i1.metadata, i2.metadata)
# Unless inference is perfect, internal nodes may differ, but sample nodes
# should be identical
# should be identical. Node metadata is not transferred, however, and a tsinfer-
# specific node metadata schema is used (where empty is None rather than b"")
assert ts.table_metadata_schemas.node == tsinfer.formats.node_metadata_schema()
for n1, n2 in zip(ts.samples(), ts_inferred.samples()):
assert ts.node(n1) == ts_inferred.node(n2)
assert ts.node(n1).replace(metadata=None) == ts_inferred.node(n2)
# Sites can have metadata added by the inference process, but inferred site
# metadata should always include all the metadata in the original ts
for s1, s2 in zip(ts.sites(), ts_inferred.sites()):
Expand Down Expand Up @@ -1586,7 +1588,7 @@ def verify(self, sample_data, mismatch_ratio=None, recombination_rate=None):
ancestors_time = ancestor_data.ancestors_time[:]
num_ancestor_nodes = 0
for n in ancestors_ts.nodes():
md = json.loads(n.metadata) if n.metadata else {}
md = n.metadata if n.metadata else {}
if tsinfer.is_pc_ancestor(n.flags):
assert not ("ancestor_data_id" in md)
else:
Expand Down Expand Up @@ -3114,8 +3116,7 @@ def verify_augmented_ancestors(
node = t2.nodes[m + j]
assert node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR
assert node.time == 1
metadata = json.loads(node.metadata.decode())
assert node_id == metadata["sample_data_id"]
assert node_id == node.metadata["sample_data_id"]

t2.nodes.truncate(len(t1.nodes))
# Adding and subtracting 1 can lead to small diffs, so we compare
Expand Down Expand Up @@ -3265,8 +3266,7 @@ def verify_example(self, full_subset, samples, ancestors, path_compression):
num_sample_ancestors = 0
for node in final_ts.nodes():
if node.flags == tsinfer.NODE_IS_SAMPLE_ANCESTOR:
metadata = json.loads(node.metadata.decode())
assert metadata["sample_data_id"] in subset
assert node.metadata["sample_data_id"] in subset
num_sample_ancestors += 1
assert expected_sample_ancestors == num_sample_ancestors
tsinfer.verify(samples, final_ts.simplify())
Expand Down
26 changes: 26 additions & 0 deletions tsinfer/formats.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,32 @@ def permissive_json_schema():
}


def node_metadata_schema():
    """
    Return the metadata schema that tsinfer applies to the node table.

    The schema uses the tskit ``struct`` codec and defines exactly two
    required integer fields, each with a default of -1:

    - ``ancestor_data_id``: ID of the ancestor that a node corresponds to.
    - ``sample_data_id``: ID of the sample that a node corresponds to.

    The top-level type also allows ``null``, so a node may carry no
    metadata at all. No additional properties are permitted.

    :return: The node metadata schema.
    :rtype: tskit.MetadataSchema
    """
    # This is fixed by tsinfer: users cannot add to the node metadata
    return tskit.MetadataSchema(
        {
            "codec": "struct",
            # "null" lets nodes without any tsinfer metadata be stored as empty
            "type": ["object", "null"],
            "properties": {
                "ancestor_data_id": {
                    "description": (
                        "The ID of the corresponding ancestor in the ancestor "
                        "data file, or -1 if not applicable"
                    ),
                    "type": "integer",
                    "binaryFormat": "i",
                    "default": -1,
                },
                "sample_data_id": {
                    "description": (
                        "The ID of the corresponding sample in the sample "
                        "data file, or -1 if not applicable"
                    ),
                    "type": "integer",
                    "binaryFormat": "i",
                    "default": -1,
                },
            },
            "required": ["ancestor_data_id", "sample_data_id"],
            "additionalProperties": False,
        }
    )


def np_obj_equal(np_obj_array1, np_obj_array2):
"""
A replacement for np.array_equal to test equality of numpy arrays that
Expand Down
21 changes: 11 additions & 10 deletions tsinfer/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -1411,12 +1411,7 @@ def get_ancestors_tree_sequence(self):
pc_ancestors = is_pc_ancestor(flags)
tables.nodes.set_columns(flags=flags, time=times)

# # FIXME we should do this as a struct codec?
# dict_schema = permissive_json_schema()
# dict_schema = add_to_schema(dict_schema, "ancestor_data_id",
# {"type": "integer"})
# schema = tskit.MetadataSchema(dict_schema)
# tables.nodes.schema = schema
tables.nodes.metadata_schema = formats.node_metadata_schema()

# Add metadata for any non-PC node, pointing to the original ancestor
metadata = []
Expand All @@ -1425,7 +1420,11 @@ def get_ancestors_tree_sequence(self):
if is_pc:
metadata.append(b"")
else:
metadata.append(_encode_raw_metadata({"ancestor_data_id": ancestor}))
metadata.append(
tables.nodes.metadata_schema.validate_and_encode_row(
{"ancestor_data_id": ancestor, "sample_data_id": tskit.NULL}
)
)
ancestor += 1
tables.nodes.packset_metadata(metadata)
left, right, parent, child = tsb.dump_edges()
Expand Down Expand Up @@ -1471,6 +1470,7 @@ def store_output(self):
tables = tskit.TableCollection(
sequence_length=self.ancestor_data.sequence_length
)
tables.nodes.metadata_schema = formats.node_metadata_schema()
ts = tables.tree_sequence()
return ts

Expand Down Expand Up @@ -1830,9 +1830,10 @@ def get_augmented_ancestors_tree_sequence(self, sample_indexes):
tables.nodes.add_row(
flags=constants.NODE_IS_SAMPLE_ANCESTOR,
time=times[j],
metadata=_encode_raw_metadata(
{"sample_data_id": int(sample_indexes[s])}
),
metadata={
"ancestor_data_id": tskit.NULL,
"sample_data_id": int(sample_indexes[s]),
},
)
s += 1
else:
Expand Down

0 comments on commit 3786aea

Please sign in to comment.