Skip to content

Commit

Permalink
templating: support on top-level params/metrics/plots/artifacts/datasets (#10359)
Browse files Browse the repository at this point in the history

* templating: support on top-level params/metrics/plots/artifacts/datasets

* make resolving top-level ordered

* prevent updating parametrized dataset

* fix dataset tests

* add tests

* fix comment

* support parametrizing artifacts name
  • Loading branch information
skshetry authored Mar 22, 2024
1 parent 8d8939f commit c62a038
Show file tree
Hide file tree
Showing 8 changed files with 442 additions and 16 deletions.
36 changes: 21 additions & 15 deletions dvc/dvcfile.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,16 +242,22 @@ def dump(self, stage, update_pipeline=True, update_lock=True, **kwargs):

def dump_dataset(self, dataset):
    """Insert or update a dataset entry in dvc.yaml and track the file in SCM.

    Entries are matched by their ``name`` key. An entry whose raw
    (un-interpolated) form differs from its resolved form is parametrized;
    updating it in place would destroy the templating, so it is refused.
    """
    with modify_yaml(self.path, fs=self.repo.fs) as data:
        # `parsed` is the resolved (interpolated) view; `raw` is what is
        # literally stored in the file. They differ only for parametrized
        # entries.
        parsed = self.datasets if data else []
        raw = data.setdefault("datasets", [])
        loc = next(
            (i for i, ds in enumerate(parsed) if ds["name"] == dataset["name"]),
            None,
        )
        if loc is not None:
            if raw[loc] != parsed[loc]:
                raise ParametrizedDumpError(
                    "cannot update a parametrized dataset entry"
                )

            apply_diff(dataset, raw[loc])
            raw[loc] = dataset
        else:
            raw.append(dataset)
    self.repo.scm_context.track_file(self.relpath)

def _dump_lockfile(self, stage, **kwargs):
Expand Down Expand Up @@ -307,29 +313,29 @@ def stages(self) -> LOADER:
return self.LOADER(self, self.contents, self.lockfile_contents)

@property
def artifacts(self) -> dict[str, Optional[dict[str, Any]]]:
    """Top-level ``artifacts`` section, with templating resolved."""
    return self.resolver.resolve_artifacts()

@property
def metrics(self) -> list[str]:
    """Top-level ``metrics`` section, with templating resolved."""
    return self.resolver.resolve_metrics()

@property
def params(self) -> list[str]:
    """Top-level ``params`` section, with templating resolved."""
    return self.resolver.resolve_params()

@property
def plots(self) -> list[Any]:
    """Top-level ``plots`` section, with templating resolved."""
    return self.resolver.resolve_plots()

@property
def datasets(self) -> list[dict[str, Any]]:
    """Top-level ``datasets`` section, with templating resolved."""
    return self.resolver.resolve_datasets()

@property
def datasets_lock(self) -> list[dict[str, Any]]:
    """Raw ``datasets`` section from the lockfile (lockfiles are never parametrized)."""
    return self.lockfile_contents.get("datasets", [])

def remove(self, force=False):
if not force:
logger.warning("Cannot remove pipeline file.")
Expand Down
91 changes: 90 additions & 1 deletion dvc/parsing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
VarsAlreadyLoaded,
)
from .interpolate import (
check_expression,
check_recursive_parse_errors,
is_interpolated_string,
recurse,
Expand All @@ -38,10 +39,16 @@

logger = logger.getChild(__name__)

# Keys recognized at the top level of dvc.yaml.
VARS_KWD = "vars"
WDIR_KWD = "wdir"

ARTIFACTS_KWD = "artifacts"
DATASETS_KWD = "datasets"
METRICS_KWD = "metrics"
PARAMS_KWD = "params"
PLOTS_KWD = "plots"
STAGES_KWD = "stages"

# Keys recognized inside a stage definition for generated stages.
FOREACH_KWD = "foreach"
MATRIX_KWD = "matrix"
DO_KWD = "do"
DO_KWD = "do"
Expand Down Expand Up @@ -163,6 +170,27 @@ def __init__(self, repo: "Repo", wdir: str, d: dict):
for name, definition in stages_data.items()
}

self.artifacts = [
ArtifactDefinition(self, self.context, name, definition, ARTIFACTS_KWD)
for name, definition in d.get(ARTIFACTS_KWD, {}).items()
]
self.datasets = [
TopDefinition(self, self.context, str(i), definition, DATASETS_KWD)
for i, definition in enumerate(d.get(DATASETS_KWD, []))
]
self.metrics = [
TopDefinition(self, self.context, str(i), definition, METRICS_KWD)
for i, definition in enumerate(d.get(METRICS_KWD, []))
]
self.params = [
TopDefinition(self, self.context, str(i), definition, PARAMS_KWD)
for i, definition in enumerate(d.get(PARAMS_KWD, []))
]
self.plots = [
TopDefinition(self, self.context, str(i), definition, PLOTS_KWD)
for i, definition in enumerate(d.get(PLOTS_KWD, []))
]

def resolve_one(self, name: str):
group, key = split_group_name(name)

Expand All @@ -186,6 +214,27 @@ def resolve(self):
logger.trace("Resolved dvc.yaml:\n%s", data)
return {STAGES_KWD: data}

# Top-level sections are eagerly evaluated, whereas stages are lazily evaluated,
# one-by-one.

def resolve_artifacts(self) -> dict[str, Optional[dict[str, Any]]]:
    """Resolve every top-level artifact definition into one mapping."""
    resolved: dict[str, Optional[dict[str, Any]]] = {}
    for definition in self.artifacts:
        resolved.update(definition.resolve())
    return resolved

def resolve_datasets(self) -> list[dict[str, Any]]:
    """Resolve each top-level dataset entry, preserving file order."""
    return [definition.resolve() for definition in self.datasets]

def resolve_metrics(self) -> list[str]:
    """Resolve each top-level metrics path, preserving file order."""
    return [definition.resolve() for definition in self.metrics]

def resolve_params(self) -> list[str]:
    """Resolve each top-level params path, preserving file order."""
    return [definition.resolve() for definition in self.params]

def resolve_plots(self) -> list[Any]:
    """Resolve each top-level plots entry, preserving file order."""
    return [definition.resolve() for definition in self.plots]

def has_key(self, key: str):
    """Return whether *key* names a known stage (or an entry of a stage group)."""
    group, stage_key = split_group_name(key)
    return self._has_group_and_key(group, stage_key)

Expand Down Expand Up @@ -565,3 +614,43 @@ def _each_iter(self, key: str) -> "DictStrAny":
return entry.resolve_stage(skip_checks=True)
except ContextError as exc:
format_and_raise(exc, f"stage '{generated}'", self.relpath)


class TopDefinition:
    """A single entry from a top-level dvc.yaml section.

    Holds the raw (possibly templated) definition together with the
    interpolation context and enough location info to report errors
    against the originating file.
    """

    def __init__(
        self,
        resolver: DataResolver,
        context: Context,
        name: str,
        definition: "Any",
        where: str,
    ):
        self.resolver = resolver
        self.relpath = resolver.relpath
        self.context = context
        self.name = name
        self.definition = definition
        self.where = where

    def resolve(self):
        """Interpolate the definition, re-raising parse/context errors with location info."""
        try:
            check_recursive_parse_errors(self.definition)
            return self.context.resolve(self.definition)
        except (ParseError, ContextError) as exc:
            format_and_raise(exc, f"'{self.where}.{self.name}'", self.relpath)


class ArtifactDefinition(TopDefinition):
    """An ``artifacts`` entry whose key (the artifact name) may itself be templated."""

    def resolve(self) -> dict[str, Optional[dict[str, Any]]]:
        """Resolve the artifact name and body; the resolved name must be a string."""
        try:
            check_expression(self.name)
            resolved_name = self.context.resolve(self.name)
            if not isinstance(resolved_name, str):
                actual_type = type(resolved_name).__name__
                raise ResolveError(
                    f"failed to resolve '{self.where}.{self.name}'"
                    f" in '{self.relpath}': expected str, got {actual_type}"
                )
        except (ParseError, ContextError) as exc:
            format_and_raise(exc, f"'{self.where}.{self.name}'", self.relpath)
        return {resolved_name: super().resolve()}
10 changes: 10 additions & 0 deletions tests/func/artifacts/test_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -196,3 +196,13 @@ def test_get_path(tmp_dir, dvc):
assert dvc.artifacts.get_path("subdir/dvc.yaml:myart") == os.path.join(
"subdir", "myart.pkl"
)


def test_parametrized(tmp_dir, dvc):
    """Top-level ``artifacts`` values can be interpolated from params.yaml."""
    (tmp_dir / "params.yaml").dump({"path": "myart.pkl"})
    (tmp_dir / "dvc.yaml").dump(
        {"artifacts": {"myart": {"type": "model", "path": "${path}"}}}
    )
    # Use the `dvc` fixture directly, consistent with the sibling tests,
    # instead of reaching through `tmp_dir.dvc`.
    assert dvc.artifacts.read() == {
        "dvc.yaml": {"myart": Artifact(path="myart.pkl", type="model")}
    }
9 changes: 9 additions & 0 deletions tests/func/metrics/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,3 +342,12 @@ def test_cached_metrics(tmp_dir, dvc, scm, remote):
}
}
}


def test_top_level_parametrized(tmp_dir, dvc):
    """Top-level ``metrics`` entries can be interpolated from params.yaml."""
    tmp_dir.dvc_gen("metrics.yaml", "foo: 3\nbar: 10")
    (tmp_dir / "params.yaml").dump({"metric_file": "metrics.yaml"})
    (tmp_dir / "dvc.yaml").dump({"metrics": ["${metric_file}"]})
    expected = {"": {"data": {"metrics.yaml": {"data": {"foo": 3, "bar": 10}}}}}
    assert dvc.metrics.show() == expected
14 changes: 14 additions & 0 deletions tests/func/params/test_show.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,3 +178,17 @@ def test_cached_params(tmp_dir, dvc, scm, remote):
}
}
}


def test_top_level_parametrized(tmp_dir, dvc):
    """Top-level ``params`` entries can be interpolated from params.yaml."""
    (tmp_dir / "param.json").dump({"foo": 3, "bar": 10})
    (tmp_dir / "params.yaml").dump({"param_file": "param.json"})
    (tmp_dir / "dvc.yaml").dump({"params": ["${param_file}"]})
    expected = {
        "": {
            "data": {
                "param.json": {"data": {"foo": 3, "bar": 10}},
                "params.yaml": {"data": {"param_file": "param.json"}},
            }
        }
    }
    assert dvc.params.show() == expected
Loading

0 comments on commit c62a038

Please sign in to comment.