From c62a0380e9c81497288aca026d7092fcadca81b2 Mon Sep 17 00:00:00 2001 From: skshetry <18718008+skshetry@users.noreply.github.com> Date: Fri, 22 Mar 2024 18:07:16 +0545 Subject: [PATCH] templating: support on top-level params/metrics/plots/artifacts/datasets (#10359) * templating: support on top-level params/metrics/plots/artifacts/datasets * make resolving top-level ordered * prevent updating parametrized dataset * fix dataset tests * add tests * fix comment * support parametrizing artifacts name --- dvc/dvcfile.py | 36 +++--- dvc/parsing/__init__.py | 91 +++++++++++++- tests/func/artifacts/test_artifacts.py | 10 ++ tests/func/metrics/test_show.py | 9 ++ tests/func/params/test_show.py | 14 +++ tests/func/parsing/test_top_level.py | 168 +++++++++++++++++++++++++ tests/func/plots/test_show.py | 70 +++++++++++ tests/func/test_dataset.py | 60 +++++++++ 8 files changed, 442 insertions(+), 16 deletions(-) create mode 100644 tests/func/parsing/test_top_level.py diff --git a/dvc/dvcfile.py b/dvc/dvcfile.py index dde1aac8a5..1a150e8c9d 100644 --- a/dvc/dvcfile.py +++ b/dvc/dvcfile.py @@ -242,16 +242,22 @@ def dump(self, stage, update_pipeline=True, update_lock=True, **kwargs): def dump_dataset(self, dataset): with modify_yaml(self.path, fs=self.repo.fs) as data: - datasets: list[dict] = data.setdefault("datasets", []) + parsed = self.datasets if data else [] + raw = data.setdefault("datasets", []) loc = next( - (i for i, ds in enumerate(datasets) if ds["name"] == dataset["name"]), + (i for i, ds in enumerate(parsed) if ds["name"] == dataset["name"]), None, ) if loc is not None: - apply_diff(dataset, datasets[loc]) - datasets[loc] = dataset + if raw[loc] != parsed[loc]: + raise ParametrizedDumpError( + "cannot update a parametrized dataset entry" + ) + + apply_diff(dataset, raw[loc]) + raw[loc] = dataset else: - datasets.append(dataset) + raw.append(dataset) self.repo.scm_context.track_file(self.relpath) def _dump_lockfile(self, stage, **kwargs): @@ -307,29 +313,29 @@ def stages(self) -> LOADER: return self.LOADER(self, self.contents, self.lockfile_contents) @property - def metrics(self) -> list[str]: - return self.contents.get("metrics", []) + def artifacts(self) -> dict[str, Optional[dict[str, Any]]]: + return self.resolver.resolve_artifacts() @property - def plots(self) -> Any: - return self.contents.get("plots", {}) + def metrics(self) -> list[str]: + return self.resolver.resolve_metrics() @property def params(self) -> list[str]: - return self.contents.get("params", []) + return self.resolver.resolve_params() + + @property + def plots(self) -> list[Any]: + return self.resolver.resolve_plots() @property def datasets(self) -> list[dict[str, Any]]: - return self.contents.get("datasets", []) + return self.resolver.resolve_datasets() @property def datasets_lock(self) -> list[dict[str, Any]]: return self.lockfile_contents.get("datasets", []) - @property - def artifacts(self) -> dict[str, Optional[dict[str, Any]]]: - return self.contents.get("artifacts", {}) - def remove(self, force=False): if not force: logger.warning("Cannot remove pipeline file.") diff --git a/dvc/parsing/__init__.py b/dvc/parsing/__init__.py index 772d8620de..b6be7c5e21 100644 --- a/dvc/parsing/__init__.py +++ b/dvc/parsing/__init__.py @@ -21,6 +21,7 @@ VarsAlreadyLoaded, ) from .interpolate import ( + check_expression, check_recursive_parse_errors, is_interpolated_string, recurse, @@ -38,10 +39,16 @@ logger = logger.getChild(__name__) -STAGES_KWD = "stages" VARS_KWD = "vars" WDIR_KWD = "wdir" + +ARTIFACTS_KWD = "artifacts" +DATASETS_KWD = "datasets" +METRICS_KWD = "metrics" PARAMS_KWD = "params" +PLOTS_KWD = "plots" +STAGES_KWD = "stages" + FOREACH_KWD = "foreach" MATRIX_KWD = "matrix" DO_KWD = "do" @@ -163,6 +170,27 @@ def __init__(self, repo: "Repo", wdir: str, d: dict): for name, definition in stages_data.items() } + self.artifacts = [ + ArtifactDefinition(self, self.context, name, definition, ARTIFACTS_KWD) + for name, definition in d.get(ARTIFACTS_KWD, {}).items() + ] + self.datasets = [ + TopDefinition(self, self.context, str(i), definition, DATASETS_KWD) + for i, definition in enumerate(d.get(DATASETS_KWD, [])) + ] + self.metrics = [ + TopDefinition(self, self.context, str(i), definition, METRICS_KWD) + for i, definition in enumerate(d.get(METRICS_KWD, [])) + ] + self.params = [ + TopDefinition(self, self.context, str(i), definition, PARAMS_KWD) + for i, definition in enumerate(d.get(PARAMS_KWD, [])) + ] + self.plots = [ + TopDefinition(self, self.context, str(i), definition, PLOTS_KWD) + for i, definition in enumerate(d.get(PLOTS_KWD, [])) + ] + def resolve_one(self, name: str): group, key = split_group_name(name) @@ -186,6 +214,27 @@ def resolve(self): logger.trace("Resolved dvc.yaml:\n%s", data) return {STAGES_KWD: data} + # Top-level sections are eagerly evaluated, whereas stages are lazily evaluated, + # one-by-one. + + def resolve_artifacts(self) -> dict[str, Optional[dict[str, Any]]]: + d: dict[str, Optional[dict[str, Any]]] = {} + for item in self.artifacts: + d.update(item.resolve()) + return d + + def resolve_datasets(self) -> list[dict[str, Any]]: + return [item.resolve() for item in self.datasets] + + def resolve_metrics(self) -> list[str]: + return [item.resolve() for item in self.metrics] + + def resolve_params(self) -> list[str]: + return [item.resolve() for item in self.params] + + def resolve_plots(self) -> list[Any]: + return [item.resolve() for item in self.plots] + def has_key(self, key: str): return self._has_group_and_key(*split_group_name(key)) @@ -565,3 +614,43 @@ def _each_iter(self, key: str) -> "DictStrAny": return entry.resolve_stage(skip_checks=True) except ContextError as exc: format_and_raise(exc, f"stage '{generated}'", self.relpath) + + +class TopDefinition: + def __init__( + self, + resolver: DataResolver, + context: Context, + name: str, + definition: "Any", + where: str, + ): + self.resolver = resolver + self.context = context + self.name = name + self.definition = definition + self.where = where + self.relpath = self.resolver.relpath + + def resolve(self): + try: + check_recursive_parse_errors(self.definition) + return self.context.resolve(self.definition) + except (ParseError, ContextError) as exc: + format_and_raise(exc, f"'{self.where}.{self.name}'", self.relpath) + + +class ArtifactDefinition(TopDefinition): + def resolve(self) -> dict[str, Optional[dict[str, Any]]]: + try: + check_expression(self.name) + name = self.context.resolve(self.name) + if not isinstance(name, str): + typ = type(name).__name__ + raise ResolveError( + f"failed to resolve '{self.where}.{self.name}'" + f" in '{self.relpath}': expected str, got " + typ + ) + except (ParseError, ContextError) as exc: + format_and_raise(exc, f"'{self.where}.{self.name}'", self.relpath) + return {name: super().resolve()} diff --git a/tests/func/artifacts/test_artifacts.py b/tests/func/artifacts/test_artifacts.py index ea2553c36b..fcad4d3071 100644 --- a/tests/func/artifacts/test_artifacts.py +++ b/tests/func/artifacts/test_artifacts.py @@ -196,3 +196,13 @@ def test_get_path(tmp_dir, dvc): assert dvc.artifacts.get_path("subdir/dvc.yaml:myart") == os.path.join( "subdir", "myart.pkl" ) + + +def test_parametrized(tmp_dir, dvc): + (tmp_dir / "params.yaml").dump({"path": "myart.pkl"}) + (tmp_dir / "dvc.yaml").dump( + {"artifacts": {"myart": {"type": "model", "path": "${path}"}}} + ) + assert tmp_dir.dvc.artifacts.read() == { + "dvc.yaml": {"myart": Artifact(path="myart.pkl", type="model")} + } diff --git a/tests/func/metrics/test_show.py b/tests/func/metrics/test_show.py index 9f653df6a0..d9ea003c28 100644 --- a/tests/func/metrics/test_show.py +++ b/tests/func/metrics/test_show.py @@ -342,3 +342,12 @@ def test_cached_metrics(tmp_dir, dvc, scm, remote): } } } + + +def test_top_level_parametrized(tmp_dir, dvc): + tmp_dir.dvc_gen("metrics.yaml", "foo: 3\nbar: 10") + (tmp_dir / "params.yaml").dump({"metric_file": "metrics.yaml"}) + (tmp_dir / "dvc.yaml").dump({"metrics": ["${metric_file}"]}) + assert dvc.metrics.show() == { + "": {"data": {"metrics.yaml": {"data": {"foo": 3, "bar": 10}}}} + } diff --git a/tests/func/params/test_show.py b/tests/func/params/test_show.py index f528f6da62..9b81545f8e 100644 --- a/tests/func/params/test_show.py +++ b/tests/func/params/test_show.py @@ -178,3 +178,17 @@ def test_cached_params(tmp_dir, dvc, scm, remote): } } } + + +def test_top_level_parametrized(tmp_dir, dvc): + (tmp_dir / "param.json").dump({"foo": 3, "bar": 10}) + (tmp_dir / "params.yaml").dump({"param_file": "param.json"}) + (tmp_dir / "dvc.yaml").dump({"params": ["${param_file}"]}) + assert dvc.params.show() == { + "": { + "data": { + "param.json": {"data": {"foo": 3, "bar": 10}}, + "params.yaml": {"data": {"param_file": "param.json"}}, + } + } + } diff --git a/tests/func/parsing/test_top_level.py b/tests/func/parsing/test_top_level.py new file mode 100644 index 0000000000..6f9419157c --- /dev/null +++ b/tests/func/parsing/test_top_level.py @@ -0,0 +1,168 @@ +from dvc.parsing import DataResolver + + +def test_params(tmp_dir, dvc): + (tmp_dir / "params.yaml").dump( + {"params": {"param1": "params.json", "param2": "params.toml"}} + ) + + template = {"params": ["${params.param1}", "param11", "${params.param2}"]} + resolver = DataResolver(dvc, tmp_dir, template) + assert resolver.resolve_params() == ["params.json", "param11", "params.toml"] + + +def test_metrics(tmp_dir, dvc): + (tmp_dir / "params.yaml").dump( + {"metrics": {"metric1": "metrics.json", "metric2": "metrics.toml"}} + ) + + template = {"metrics": ["${metrics.metric1}", "metric11", "${metrics.metric2}"]} + resolver = DataResolver(dvc, tmp_dir, template) + assert resolver.resolve_metrics() == ["metrics.json", "metric11", "metrics.toml"] + + +def test_plots(tmp_dir, dvc): + template = { + "plots": [ + { + "${plots.plot1_name}": { + "x": "${plots.x_cls}", + "y": { + "train_classes.csv": "${plots.y_train_cls}", + "test_classes.csv": [ + "${plots.y_train_cls}", + "${plots.y_test_cls}", + ], + }, + "title": "Compare test vs train confusion matrix", + "template": "confusion", + "x_label": "Actual class", + "y_label": "Predicted class", + } + }, + {"eval/importance2.png": None}, + {"${plots.plot3_name}": None}, + "eval/importance4.png", + "${plots.plot5_name}", + ], + } + + (tmp_dir / "params.yaml").dump( + { + "plots": { + "x_cls": "actual_class", + "y_train_cls": "predicted_class", + "y_test_cls": "predicted_class2", + "plot1_name": "eval/importance1.png", + "plot3_name": "eval/importance3.png", + "plot5_name": "eval/importance5.png", + } + } + ) + resolver = DataResolver(dvc, tmp_dir, template) + assert resolver.resolve_plots() == [ + { + "eval/importance1.png": { + "x": "actual_class", + "y": { + "train_classes.csv": "predicted_class", + "test_classes.csv": ["predicted_class", "predicted_class2"], + }, + "title": "Compare test vs train confusion matrix", + "template": "confusion", + "x_label": "Actual class", + "y_label": "Predicted class", + } + }, + {"eval/importance2.png": None}, + {"eval/importance3.png": None}, + "eval/importance4.png", + "eval/importance5.png", + ] + + +def test_artifacts(tmp_dir, dvc): + template = { + "artifacts": { + "${artifacts.name}": { + "path": "${artifacts.path}", + "type": "model", + "desc": "CV classification model, ResNet50", + "labels": ["${artifacts.label1}", "${artifacts.label2}"], + "meta": {"framework": "${artifacts.framework}"}, + } + } + } + + (tmp_dir / "params.yaml").dump( + { + "artifacts": { + "name": "cv-classification", + "path": "models/resnet.pt", + "label1": "resnet50", + "label2": "classification", + "framework": "pytorch", + } + } + ) + + resolver = DataResolver(dvc, tmp_dir, template) + assert resolver.resolve_artifacts() == { + "cv-classification": { + "path": "models/resnet.pt", + "type": "model", + "desc": "CV classification model, ResNet50", + "labels": ["resnet50", "classification"], + "meta": {"framework": "pytorch"}, + } + } + + +def test_datasets(tmp_dir, dvc): + template = { + "datasets": [ + {"name": "${ds1.name}", "url": "${ds1.url}", "type": "dvcx"}, + { + "name": "${ds2.name}", + "url": "${ds2.url}", + "type": "dvc", + "path": "${ds2.path}", + }, + { + "name": "${ds3.name}", + "url": "${ds3.url}", + "type": "url", + }, + ] + } + + (tmp_dir / "params.yaml").dump( + { + "ds1": {"name": "dogs", "url": "dvcx://dogs"}, + "ds2": { + "name": "example-get-started", + "url": "git@github.com:iterative/example-get-started.git", + "path": "path", + }, + "ds3": { + "name": "cloud-versioning-demo", + "url": "s3://cloud-versioning-demo", + }, + } + ) + + resolver = DataResolver(dvc, tmp_dir, template) + assert resolver.resolve_datasets() == [ + {"name": "dogs", "url": "dvcx://dogs", "type": "dvcx"}, + { + "name": "example-get-started", + "url": "git@github.com:iterative/example-get-started.git", + "type": "dvc", + "path": "path", + }, + { + "name": "cloud-versioning-demo", + "url": "s3://cloud-versioning-demo", + "type": "url", + }, + ] diff --git a/tests/func/plots/test_show.py b/tests/func/plots/test_show.py index 526a390408..8cb9327e92 100644 --- a/tests/func/plots/test_show.py +++ b/tests/func/plots/test_show.py @@ -500,3 +500,73 @@ def test_show_plots_defined_with_native_os_path(tmp_dir, dvc, scm, capsys): json_data = json_out["data"] assert json_data[f"{top_level_plot}"] assert json_data[stage_plot] + + +@pytest.mark.parametrize( + "plot_config,expanded_config,expected_datafiles", + [ + ( + { + "comparison": { + "x": {"${data1}": "${a}"}, + "y": {"sub/dir/data2.json": "${b}"}, + } + }, + { + "comparison": { + "x": {"data1.json": "a"}, + "y": {"sub/dir/data2.json": "b"}, + } + }, + ["data1.json", os.path.join("sub", "dir", "data2.json")], + ), + ( + {"${data1}": None}, + {"data1.json": {}}, + ["data1.json"], + ), + ( + "${data1}", + {"data1.json": {}}, + ["data1.json"], + ), + ], +) +def test_top_level_parametrized( + tmp_dir, dvc, plot_config, expanded_config, expected_datafiles +): + (tmp_dir / "params.yaml").dump( + {"data1": "data1.json", "a": "a", "b": "b", "c": "c"} + ) + data = { + "data1.json": [ + {"a": 1, "b": 0.1, "c": 0.01}, + {"a": 2, "b": 0.2, "c": 0.02}, + ], + os.path.join("sub", "dir", "data.json"): [ + {"a": 6, "b": 0.6, "c": 0.06}, + {"a": 7, "b": 0.7, "c": 0.07}, + ], + } + + for filename, content in data.items(): + dirname = os.path.dirname(filename) + if dirname: + os.makedirs(dirname) + (tmp_dir / filename).dump_json(content, sort_keys=True) + + config_file = "dvc.yaml" + with modify_yaml(config_file) as dvcfile_content: + dvcfile_content["plots"] = [plot_config] + + result = dvc.plots.show() + + assert expanded_config == get_plot( + result, "workspace", typ="definitions", file=config_file + ) + + for filename, content in data.items(): + if filename in expected_datafiles: + assert content == get_plot(result, "workspace", file=filename) + else: + assert filename not in get_plot(result, "workspace") diff --git a/tests/func/test_dataset.py b/tests/func/test_dataset.py index 77a83e1fe2..b6a15aff62 100644 --- a/tests/func/test_dataset.py +++ b/tests/func/test_dataset.py @@ -412,3 +412,63 @@ def test_collect(tmp_dir, dvc): with pytest.raises(DatasetNotFoundError, match=r"^dataset not found$"): dvc.datasets["not-existing"] + + +def test_parametrized(tmp_dir, dvc): + (tmp_dir / "dvc.yaml").dump( + { + "datasets": [ + {"name": "${ds1.name}", "url": "${ds1.url}", "type": "dvcx"}, + { + "name": "${ds2.name}", + "url": "${ds2.url}", + "type": "dvc", + "path": "${ds2.path}", + }, + { + "name": "${ds3.name}", + "url": "${ds3.url}", + "type": "url", + }, + ] + } + ) + (tmp_dir / "params.yaml").dump( + { + "ds1": {"name": "dogs", "url": "dvcx://dogs"}, + "ds2": { + "name": "example-get-started", + "url": "git@github.com:iterative/example-get-started.git", + "path": "path", + }, + "ds3": { + "name": "cloud-versioning-demo", + "url": "s3://cloud-versioning-demo", + }, + } + ) + + path = (tmp_dir / "dvc.yaml").fs_path + assert dict(dvc.datasets.items()) == { + "dogs": DVCXDataset( + manifest_path=path, + spec=DatasetSpec(name="dogs", url="dvcx://dogs", type="dvcx"), + ), + "example-get-started": DVCDataset( + manifest_path=path, + spec=DVCDatasetSpec( + name="example-get-started", + url="git@github.com:iterative/example-get-started.git", + path="path", + type="dvc", + ), + ), + "cloud-versioning-demo": URLDataset( + manifest_path=path, + spec=DatasetSpec( + name="cloud-versioning-demo", + url="s3://cloud-versioning-demo", + type="url", + ), + ), + }