Skip to content

Commit

Permalink
When creating an mlc.Metadata object, share the graph with all node…
Browse files Browse the repository at this point in the history
…s. (#713)
  • Loading branch information
marcenacp authored Jul 16, 2024
1 parent 09f804b commit 2cb765e
Show file tree
Hide file tree
Showing 3 changed files with 29 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,15 @@ def parent(self) -> "Node | None":
@property
def predecessors(self) -> set["Node"]:
    """Predecessors in the structure graph.

    Returns:
        The nodes yielded by `self.ctx.graph.predecessors(self)`, collected
        into a set.

    Raises:
        KeyError: If looking up this node in `self.ctx.graph` raises a
            `KeyError`. The error is re-raised with a message explaining how
            to fix it, and the original error is attached as its cause.
    """
    try:
        predecessors = self.ctx.graph.predecessors(self)
        # Materialize inside the `try`: the graph may hand back a lazy
        # iterator, in which case the lookup failure only surfaces here.
        return set(predecessors)  # pytype: disable=bad-return-type
    except KeyError as e:
        # Chain explicitly so the traceback presents the low-level lookup
        # failure as the direct cause of this more actionable error.
        raise KeyError(
            f"Could not find node '{self.id}' in the graph. Make sure to build a"
            " full mlcroissant metadata object (mlc.Metadata) wrapping all the"
            " FileSets/FileObjects/RecordSets/Fields."
        ) from e

@property
def recursive_predecessors(self) -> set["Node"]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,10 @@ def __post_init__(self):
self.ctx, self.ctx.conforms_to
)

# Share the structure graph in the context
for node in self.nodes():
node.ctx.graph = self.ctx.graph

def to_json(self) -> Json:
"""Converts the `Metadata` to JSON."""
context = self.ctx.rdf.context
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from mlcroissant._src.core.issues import ValidationError
from mlcroissant._src.structure_graph.base_node import Node
from mlcroissant._src.structure_graph.nodes.creative_work import CreativeWork
from mlcroissant._src.structure_graph.nodes.field import Field
from mlcroissant._src.structure_graph.nodes.metadata import Metadata
from mlcroissant._src.structure_graph.nodes.record_set import RecordSet
from mlcroissant._src.tests.nodes import create_test_node
Expand Down Expand Up @@ -175,3 +176,18 @@ def test_validate_license():
]
with pytest.raises(ValidationError, match="License should be a list of str"):
Metadata(name="foo", license=42) # pytype: disable=wrong-arg-types


def test_predecessors_are_propagated():
    """A field wrapped in a Metadata reports its record set as predecessor."""
    name_field = Field(
        id="records/name",
        name="name",
        data_types=constants.DataType.TEXT,
    )
    records = RecordSet(
        id="records",
        fields=[name_field],
        data=[{"name": "train"}, {"name": "test"}],
    )
    # The Metadata instance itself is not needed afterwards: it is built only
    # so that the nodes it wraps can be queried through the structure graph.
    Metadata(name="dummy", record_sets=[records])
    expected_predecessors = {records}
    assert name_field.predecessors == expected_predecessors

0 comments on commit 2cb765e

Please sign in to comment.