diff --git a/dvc/cli/parser.py b/dvc/cli/parser.py
index 5ff5199af6..88a60d1c2a 100644
--- a/dvc/cli/parser.py
+++ b/dvc/cli/parser.py
@@ -26,6 +26,7 @@
     get_url,
     git_hook,
     imp,
+    imp_db,
    imp_url,
     init,
     install,
@@ -65,6 +66,7 @@
     data_sync,
     gc,
     imp,
+    imp_db,
     imp_url,
     config,
     checkout,
diff --git a/dvc/commands/imp_db.py b/dvc/commands/imp_db.py
new file mode 100644
index 0000000000..97faf430ec
--- /dev/null
+++ b/dvc/commands/imp_db.py
@@ -0,0 +1,90 @@
+import argparse
+
+from dvc.cli import completion
+from dvc.cli.command import CmdBase
+from dvc.cli.utils import append_doc_link
+from dvc.log import logger
+
+logger = logger.getChild(__name__)
+
+
+class CmdImportDb(CmdBase):
+    def run(self):
+        if not (self.args.sql or self.args.model):
+            raise argparse.ArgumentTypeError("Either --sql or --model is required.")
+
+        self.repo.imp_db(
+            url=self.args.url,
+            rev=self.args.rev,
+            project_dir=self.args.project_dir,
+            sql=self.args.sql,
+            model=self.args.model,
+            profile=self.args.profile,
+            target=self.args.target,
+            output_format=self.args.output_format,
+            out=self.args.out,
+            force=self.args.force,
+        )
+        return 0
+
+
+def add_parser(subparsers, parent_parser):
+    IMPORT_HELP = (
+        "Download the results of a SQL query or a dbt model "
+        "into the workspace, and track it."
+    )
+
+    import_parser = subparsers.add_parser(
+        "import-db",
+        parents=[parent_parser],
+        description=append_doc_link(IMPORT_HELP, "import-db"),
+        add_help=False,
+    )
+    import_parser.add_argument(
+        "--url", help="Location of the Git repository with the dbt project"
+    )
+    import_parser.add_argument(
+        "--rev",
+        nargs="?",
+        help="Git revision (e.g. SHA, branch, tag)",
+        metavar="<commit>",
+    )
+    import_parser.add_argument(
+        "--project-dir", nargs="?", help="Subdirectory where the dbt project is located"
+    )
+
+    group = import_parser.add_mutually_exclusive_group()
+    group.add_argument(
+        "--sql",
+        help="SQL query",
+    )
+    group.add_argument(
+        "--model",
+        help="Model name to download",
+    )
+    import_parser.add_argument("--profile", help="Profile to use")
+    import_parser.add_argument("--target", help="Target to use")
+    import_parser.add_argument(
+        "--output-format",
+        default="csv",
+        const="csv",
+        nargs="?",
+        choices=["csv", "json"],
+        help="Export format",
+    )
+    import_parser.add_argument(
+        "-o",
+        "--out",
+        nargs="?",
+        help="Destination path to download files to",
+        metavar="<path>",
+    ).complete = completion.FILE
+    import_parser.add_argument(
+        "-f",
+        "--force",
+        action="store_true",
+        default=False,
+        help="Overwrite destination file or directory if it exists.",
+    )
+
+    import_parser.set_defaults(func=CmdImportDb)
diff --git a/dvc/config_schema.py b/dvc/config_schema.py
index 8677c0471f..16d3e9af02 100644
--- a/dvc/config_schema.py
+++ b/dvc/config_schema.py
@@ -318,6 +318,8 @@ def __call__(self, data):
         "feature": FeatureSchema(
             {
                 Optional("machine", default=False): Bool,
+                "db_profile": str,
+                "db_target": str,
             },
         ),
         "plots": {
diff --git a/dvc/dependency/__init__.py b/dvc/dependency/__init__.py
index ce800be7bf..c96ef5f93f 100644
--- a/dvc/dependency/__init__.py
+++ b/dvc/dependency/__init__.py
@@ -4,6 +4,7 @@
 from dvc.output import ARTIFACT_SCHEMA, DIR_FILES_SCHEMA, Output
 
 from .base import Dependency
+from .db import DB_SCHEMA, PARAM_DB, DbDependency, DbtDependency
 from .param import ParamsDependency
 from .repo import RepoDependency
 
@@ -14,20 +15,28 @@
 SCHEMA: Mapping[str, Any] = {
     **ARTIFACT_SCHEMA,
     **RepoDependency.REPO_SCHEMA,
+    **DB_SCHEMA,
     Output.PARAM_FILES: [DIR_FILES_SCHEMA],
     Output.PARAM_FS_CONFIG: dict,
 }
 
 def _get(stage, p, info, **kwargs):
-    if info and info.get(RepoDependency.PARAM_REPO):
-        repo = info.pop(RepoDependency.PARAM_REPO)
-        return RepoDependency(repo, stage, p, info)
+    d = info or {}
+    db = d.get(PARAM_DB, {})
+    params = d.pop(ParamsDependency.PARAM_PARAMS, None)
+    repo = d.pop(RepoDependency.PARAM_REPO, None)
 
-    if info and info.get(ParamsDependency.PARAM_PARAMS):
-        params = info.pop(ParamsDependency.PARAM_PARAMS)
+    if params:
         return ParamsDependency(stage, p, params)
+    if DbDependency.PARAM_QUERY in db:
+        return DbDependency(stage, info)
+    if db:
+        return DbtDependency(repo, stage, info)
 
+    assert p
+    if repo:
+        return RepoDependency(repo, stage, p, info)
     return Dependency(stage, p, info, **kwargs)
 
 
@@ -44,9 +53,11 @@
     return ret
 
 
-def loads_from(stage, s_list, erepo=None, fs_config=None):
+def loads_from(stage, s_list, erepo=None, fs_config=None, db=None):
     assert isinstance(s_list, list)
     info = {RepoDependency.PARAM_REPO: erepo} if erepo else {}
+    if db:
+        info.update({"db": db})
     return [
         _get(
             stage,
diff --git a/dvc/dependency/db.py b/dvc/dependency/db.py
new file mode 100644
index 0000000000..d00f1213fc
--- /dev/null
+++ b/dvc/dependency/db.py
@@ -0,0 +1,288 @@
+import os
+from contextlib import contextmanager
+from typing import TYPE_CHECKING, Any, Dict, Iterator, Union
+
+from funcy import compact
+
+from dvc.exceptions import DvcException
+from dvc.log import logger
+from dvc.scm import SCM
+
+from .base import Dependency
+
+if TYPE_CHECKING:
+    from agate import Table
+    from rich.status import Status
+
+    from dvc.stage import Stage
+
+logger = logger.getChild(__name__)
+
+
+PARAM_DB = "db"
+PARAM_PROFILE = "profile"
+PARAM_FILE_FORMAT = "file_format"
+
+
+def _get_db_config(config: Dict) -> Dict:
+    conf = config.get("feature", {})
+    pref = "db_"
+    # NOTE: str.lstrip(pref) would strip *characters*, not the prefix,
+    # so slice the prefix off instead.
+    return {k[len(pref) :]: v for k, v in conf.items() if k.startswith(pref)}
+
+
+@contextmanager
+def log_status(msg, log=logger.debug) -> Iterator["Status"]:
+    from funcy import log_durations
+
+    from dvc.ui import ui
+
+    with log_durations(log, msg), ui.status(msg) as status:
+        yield status
+
+
+@contextmanager
+def chdir(path):
+    wdir = os.getcwd()
+    os.chdir(path)
+    try:
+        yield
+    finally:
+        os.chdir(wdir)
+
+
+def export_to(table: "Table", to: str, file_format: str = "csv") -> None:
+    exporter = {"csv": table.to_csv, "json": table.to_json}
+    return exporter[file_format](to)
+
+
+class AbstractDependency(Dependency):
+    """Dependency without workspace/fs/fs_path"""
+
+    def __init__(self, stage: "Stage", info, *args, **kwargs):
+        self.repo = stage.repo
+        self.stage = stage
+        self.fs = None
+        self.fs_path = None
+        self.def_path = None  # type: ignore[assignment]
+        self.info = info or {}
+
+
+class DbDependency(AbstractDependency):
+    PARAM_QUERY = "query"
+    QUERY_SCHEMA = {PARAM_QUERY: str}
+
+    def __init__(self, stage: "Stage", info, *args, **kwargs):
+        super().__init__(stage, info, *args, **kwargs)
+        self.target = None
+
+    def __repr__(self):
+        return "{}:{}".format(
+            self.__class__.__name__, "".join(f"{k}=={v}" for k, v in self.info.items())
+        )
+
+    def __str__(self):
+        from dvc.utils.humanize import truncate_text
+
+        db_info = self.info.get(PARAM_DB, {})
+        query = db_info.get(self.PARAM_QUERY, "[query]")
+        return truncate_text(query, 50)
+
+    def workspace_status(self):
+        return False  # no workspace to check
+
+    def status(self):
+        return self.workspace_status()
+
+    def save(self):
+        """nothing to save."""
+
+    def dumpd(self, **kwargs):
+        db_info = compact(self.info.get(PARAM_DB, {}))
+        return {PARAM_DB: db_info} if db_info else {}
+
+    def update(self, rev=None):
+        """nothing to update."""
+
+    def download(self, to, jobs=None, file_format=None):  # noqa: ARG002
+        db_info = self.info.get(PARAM_DB, {})
+        query = db_info.get(self.PARAM_QUERY)
+        if not query:
+            raise DvcException("Cannot download: no query specified")
+
+        from dvc.utils.db import _check_dbt, _profiles_dir, execute_sql
+
+        db_config = _get_db_config(self.repo.config)
+        profile = db_info.get(PARAM_PROFILE) or db_config.get(PARAM_PROFILE)
+        target = self.target or db_config.get("target")
+        file_format = file_format or db_info.get(PARAM_FILE_FORMAT, "csv")
+
+        _check_dbt(self.PARAM_QUERY)
+        profiles_dir = _profiles_dir(self.repo.root_dir)
+        with log_status("Executing query") as status:
+            table = execute_sql(
+                query,
+                profiles_dir,
+                self.repo.root_dir,
+                profile,
+                target=target,
+                status=status,
+            )
+        # NOTE: we keep everything in memory, and then export it out later.
+        with log_status(f"Saving to {to}"):
+            return export_to(table, to.fs_path, file_format)
+
+
+class DbtDependency(AbstractDependency):
+    PARAM_MODEL = "model"
+    PARAM_VERSION = "version"
+    PARAM_PROJECT_DIR = "project_dir"
+    DBT_SCHEMA = {
+        PARAM_MODEL: str,
+        PARAM_VERSION: str,
+        PARAM_PROJECT_DIR: str,
+    }
+
+    def __init__(self, def_repo: Dict[str, Any], stage: "Stage", info, *args, **kwargs):
+        self.def_repo = def_repo or {}
+        self.target = None
+        super().__init__(stage, info, *args, **kwargs)
+
+    def __repr__(self):
+        return "{}:{}".format(
+            self.__class__.__name__,
+            "".join(f"{k}=={v}" for k, v in {**self.def_repo, **self.info}.items()),
+        )
+
+    def __str__(self):
+        from .repo import RepoDependency
+
+        repo = self.def_repo.get(RepoDependency.PARAM_URL)
+        rev = self.def_repo.get(RepoDependency.PARAM_REV)
+
+        db_info = self.info.get(PARAM_DB, {})
+        db = db_info.get(self.PARAM_MODEL, "")
+        project_dir = db_info.get(self.PARAM_PROJECT_DIR, "")
+        repo_info = ""
+        if repo:
+            repo_info += repo
+        if rev:
+            repo_info += f"@{rev}"
+        if project_dir:
+            repo_info += f":/{project_dir}"
+        return db + (f"({repo_info})" if repo_info else "")
+
+    @property
+    def locked_rev(self):
+        from .repo import RepoDependency
+
+        return self.def_repo.get(RepoDependency.PARAM_REV_LOCK)
+
+    @property
+    def rev(self):
+        from .repo import RepoDependency
+
+        return self.def_repo.get(RepoDependency.PARAM_REV)
+
+    def workspace_status(self):
+        if not self.def_repo:
+            return
+
+        current = self._get_clone(self.locked_rev or self.rev).get_rev()
+        updated = self._get_clone(self.rev).get_rev()
+        if current != updated:
+            return {str(self): "update available"}
+        return {}
+
+    def status(self):
+        return self.workspace_status()
+
+    def save(self):
+        from .repo import RepoDependency
+
+        if not self.def_repo:
+            return
+
+        rev = self._get_clone(self.locked_rev or self.rev).get_rev()
+        if self.def_repo.get(RepoDependency.PARAM_REV_LOCK) is None:
+            self.def_repo[RepoDependency.PARAM_REV_LOCK] = rev
+
+    def dumpd(self, **kwargs) -> Dict[str, Union[str, Dict[str, str]]]:
+        from .repo import RepoDependency
+
+        def_repo = {}
+        if self.def_repo:
+            def_repo = RepoDependency._dump_def_repo(self.def_repo)
+
+        db_info = compact(self.info.get(PARAM_DB, {}))
+        return compact({RepoDependency.PARAM_REPO: def_repo, PARAM_DB: db_info})
+
+    def _get_clone(self, rev):
+        from dvc.repo.open_repo import _cached_clone
+
+        from .repo import RepoDependency
+
+        url = self.def_repo.get(RepoDependency.PARAM_URL)
+        repo_root = self.repo.root_dir if self.repo else os.getcwd()
+        return SCM(_cached_clone(url, rev) if url else repo_root)
+
+    def update(self, rev=None):
+        from .repo import RepoDependency
+
+        if not self.def_repo:
+            return
+
+        if rev:
+            self.def_repo[RepoDependency.PARAM_REV] = rev
+        else:
+            rev = self.rev
+        self.def_repo[RepoDependency.PARAM_REV_LOCK] = self._get_clone(rev).get_rev()
+
+    def download(self, to, jobs=None, file_format=None):  # noqa: ARG002
+        from dvc.ui import ui
+
+        from .repo import RepoDependency
+
+        project_dir = self.info.get(PARAM_DB, {}).get(self.PARAM_PROJECT_DIR, "")
+        if self.def_repo:
+            repo = self._get_clone(self.locked_rev or self.rev)
+            self.def_repo[RepoDependency.PARAM_REV_LOCK] = repo.get_rev()
+            root = wdir = repo.root_dir
+        else:
+            root = self.repo.root_dir
+            wdir = self.stage.wdir
+
+        project_path = os.path.join(wdir, project_dir) if project_dir else root
+
+        with chdir(project_path):
+            self._download_db(to, file_format=file_format)
+        ui.write(f"Saved file to {to}", styled=True)
+
+    def _download_db(self, to, version=None, file_format=None):
+        from dvc.utils.db import get_model
+
+        db_info = self.info.get(PARAM_DB, {})
+        model = db_info.get(self.PARAM_MODEL)
+        if not model:
+            raise DvcException("Cannot download: no model specified")
+
+        db_config = _get_db_config(self.repo.config)
+        version = version or db_info.get(self.PARAM_VERSION)
+        profile = db_info.get(PARAM_PROFILE) or db_config.get(PARAM_PROFILE)
+        target = self.target or db_config.get("target")
+        file_format = file_format or db_info.get(PARAM_FILE_FORMAT, "csv")
+
+        with log_status("Downloading model"):
+            table = get_model(model, version=version, profile=profile, target=target)
+        # NOTE: we keep everything in memory, and then export it out later.
+        with log_status(f"Saving to {to}"):
+            export_to(table, to.fs_path, file_format=file_format)
+
+
+DB_SCHEMA = {
+    PARAM_DB: {
+        PARAM_PROFILE: str,
+        PARAM_FILE_FORMAT: str,
+        **DbDependency.QUERY_SCHEMA,
+        **DbtDependency.DBT_SCHEMA,
+    },
+}
diff --git a/dvc/dependency/repo.py b/dvc/dependency/repo.py
index 1677a08721..d84b4ed9d7 100644
--- a/dvc/dependency/repo.py
+++ b/dvc/dependency/repo.py
@@ -64,28 +64,31 @@ def save(self):
         if self.def_repo.get(self.PARAM_REV_LOCK) is None:
             self.def_repo[self.PARAM_REV_LOCK] = rev
 
-    def dumpd(self, **kwargs) -> Dict[str, Union[str, Dict[str, str]]]:
-        repo = {self.PARAM_URL: self.def_repo[self.PARAM_URL]}
+    @classmethod
+    def _dump_def_repo(cls, def_repo) -> Dict[str, str]:
+        repo = {cls.PARAM_URL: def_repo[cls.PARAM_URL]}
 
-        rev = self.def_repo.get(self.PARAM_REV)
+        rev = def_repo.get(cls.PARAM_REV)
         if rev:
-            repo[self.PARAM_REV] = self.def_repo[self.PARAM_REV]
+            repo[cls.PARAM_REV] = def_repo[cls.PARAM_REV]
 
-        rev_lock = self.def_repo.get(self.PARAM_REV_LOCK)
+        rev_lock = def_repo.get(cls.PARAM_REV_LOCK)
         if rev_lock:
-            repo[self.PARAM_REV_LOCK] = rev_lock
+            repo[cls.PARAM_REV_LOCK] = rev_lock
 
-        config = self.def_repo.get(self.PARAM_CONFIG)
+        config = def_repo.get(cls.PARAM_CONFIG)
         if config:
-            repo[self.PARAM_CONFIG] = config
+            repo[cls.PARAM_CONFIG] = config
 
-        remote = self.def_repo.get(self.PARAM_REMOTE)
+        remote = def_repo.get(cls.PARAM_REMOTE)
         if remote:
-            repo[self.PARAM_REMOTE] = remote
+            repo[cls.PARAM_REMOTE] = remote
+        return repo
 
+    def dumpd(self, **kwargs) -> Dict[str, Union[str, Dict[str, str]]]:
         return {
             self.PARAM_PATH: self.def_path,
-            self.PARAM_REPO: repo,
+            self.PARAM_REPO: self._dump_def_repo(self.def_repo),
         }
 
     def update(self, rev: Optional[str] = None):
diff --git a/dvc/output.py b/dvc/output.py
index 6fbb2bab81..a7278ce690 100644
--- a/dvc/output.py
+++ b/dvc/output.py
@@ -1523,7 +1523,7 @@ def _merge_dir_version_meta(self, other: "Output"):
 ARTIFACT_SCHEMA: Dict[Any, Any] = {
     **CHECKSUMS_SCHEMA,
     **META_SCHEMA,
-    vol.Required(Output.PARAM_PATH): str,
+    Output.PARAM_PATH: str,
     Output.PARAM_PERSIST: bool,
     Output.PARAM_CLOUD: CLOUD_SCHEMA,
     Output.PARAM_HASH: str,
diff --git a/dvc/repo/__init__.py b/dvc/repo/__init__.py
index 39330e8efb..804d5a25b3 100644
--- a/dvc/repo/__init__.py
+++ b/dvc/repo/__init__.py
@@ -76,6 +76,7 @@ class Repo:
     from dvc.repo.get import get as _get  # type: ignore[misc]
     from dvc.repo.get_url import get_url as _get_url  # type: ignore[misc]
     from dvc.repo.imp import imp  # type: ignore[misc]
+    from dvc.repo.imp_db import imp_db  # type: ignore[misc]
     from dvc.repo.imp_url import imp_url  # type: ignore[misc]
     from dvc.repo.install import install  # type: ignore[misc]
     from dvc.repo.ls import ls as _ls  # type: ignore[misc]
diff --git a/dvc/repo/graph.py b/dvc/repo/graph.py
index be7a255591..e285e1bbe5 100644
--- a/dvc/repo/graph.py
+++ b/dvc/repo/graph.py
@@ -135,6 +135,8 @@ def build_graph(stages, outs_trie=None):
     for stage in stages:
         if stage.is_repo_import:
             continue
+        if stage.is_db_import:
+            continue
 
         for dep in stage.deps:
             dep_key = dep.fs.path.parts(dep.fs_path)
@@ -160,6 +162,8 @@ def build_outs_graph(graph, outs_trie):
     for stage in graph.nodes():
         if stage.is_repo_import:
             continue
+        if stage.is_db_import:
+            continue
         for dep in stage.deps:
             dep_key = dep.fs.path.parts(dep.fs_path)
             overlapping = [n.value for n in outs_trie.prefixes(dep_key)]
diff --git a/dvc/repo/imp_db.py b/dvc/repo/imp_db.py
new file mode 100644
index 0000000000..2de37b60db
--- /dev/null
+++ b/dvc/repo/imp_db.py
@@ -0,0 +1,80 @@
+from typing import TYPE_CHECKING, Any, Dict, Optional
+
+from dvc.exceptions import OutputDuplicationError
+from dvc.repo.scm_context import scm_context
+from dvc.ui import ui
+from dvc.utils import resolve_output, resolve_paths
+
+if TYPE_CHECKING:
+    from . import Repo
+
+from . import locked
+
+
+@locked
+@scm_context
+def imp_db(  # noqa: PLR0913
+    self: "Repo",
+    url: Optional[str] = None,
+    rev: Optional[str] = None,
+    project_dir: Optional[str] = None,
+    sql: Optional[str] = None,
+    model: Optional[str] = None,
+    version: Optional[int] = None,
+    frozen: bool = True,
+    profile: Optional[str] = None,
+    target: Optional[str] = None,
+    output_format: str = "csv",
+    out: Optional[str] = None,
+    force: bool = False,
+):
+    ui.warn("WARNING: import-db is an experimental feature.")
+    ui.warn(
+        "The functionality may change or break without notice, "
+        "which could lead to unexpected behavior."
+    )
+    erepo = None
+    if model and url:
+        erepo = {"url": url}
+        if rev:
+            erepo["rev"] = rev
+
+    assert output_format in ("csv", "json")
+
+    db: Dict[str, Any] = {"file_format": output_format}
+    if profile:
+        db["profile"] = profile
+
+    if model:
+        out = out or f"{model}.{output_format}"
+        db.update({"model": model, "version": version, "project_dir": project_dir})
+    else:
+        out = out or "results.csv"
+        db["query"] = sql
+
+    out = resolve_output(url or ".", out, force=force)
+    path, wdir, out = resolve_paths(self, out, always_local=True)
+    stage = self.stage.create(
+        single_stage=True,
+        validate=False,
+        fname=path,
+        wdir=wdir,
+        deps=[url],
+        outs=[out],
+        erepo=erepo,
+        fs_config=None,
+        db=db,
+    )
+
+    stage.deps[0].target = target
+    try:
+        self.check_graph(stages={stage})
+    except OutputDuplicationError as exc:
+        raise OutputDuplicationError(  # noqa: B904
+            exc.output, set(exc.stages) - {stage}
+        )
+
+    stage.run()
+    stage.frozen = frozen
+    stage.dump()
+    return stage
diff --git a/dvc/repo/index.py b/dvc/repo/index.py
index 91588c6ff5..d49cfa90bf 100644
--- a/dvc/repo/index.py
+++ b/dvc/repo/index.py
@@ -135,7 +135,11 @@ def _load_data_from_outs(index, prefix, outs):
             hash_info=out.hash_info,
         )
 
-        if out.stage.is_import and not out.stage.is_repo_import:
+        if (
+            out.stage.is_import
+            and not out.stage.is_repo_import
+            and not out.stage.is_db_import
+        ):
             dep = out.stage.deps[0]
             entry.meta = dep.meta
             if out.hash_info:
@@ -183,6 +187,9 @@ def _load_storage_from_out(storage_map, key, out):
         except NoRemoteError:
             pass
 
+    if out.stage.is_db_import:
+        return
+
     if out.stage.is_import:
         dep = out.stage.deps[0]
         if not out.hash_info:
diff --git a/dvc/stage/__init__.py b/dvc/stage/__init__.py
index 7cec7da9c9..ade80ad6ac 100644
--- a/dvc/stage/__init__.py
+++ b/dvc/stage/__init__.py
@@ -105,7 +105,7 @@ def create_stage(cls: Type[_T], repo, path, **kwargs) -> _T:
     fill_stage_outputs(stage, **kwargs)
     check_no_externals(stage)
     fill_stage_dependencies(
-        stage, **project(kwargs, ["deps", "erepo", "params", "fs_config"])
+        stage, **project(kwargs, ["deps", "erepo", "params", "fs_config", "db"])
     )
     check_circular_dependency(stage)
     check_duplicated_arguments(stage)
@@ -281,9 +281,24 @@ def is_repo_import(self) -> bool:
         return isinstance(self.deps[0], RepoDependency)
 
+    @property
+    def is_db_import(self) -> bool:
+        if not self.is_import:
+            return False
+
+        from dvc.dependency import DbDependency, DbtDependency
+
+        return isinstance(self.deps[0], (DbDependency, DbtDependency))
+
     @property
     def is_versioned_import(self) -> bool:
-        return self.is_import and self.deps[0].fs.version_aware
+        from dvc.dependency import DbDependency, DbtDependency
+
+        return (
+            self.is_import
+            and not isinstance(self.deps[0], (DbDependency, DbtDependency))
+            and self.deps[0].fs.version_aware
+        )
 
     def short_description(self) -> Optional["str"]:
         desc: Optional["str"] = None
@@ -446,6 +461,9 @@ def update(
     ) -> None:
         if not (self.is_repo_import or self.is_import):
             raise StageUpdateError(self.relpath)
+
+        # always force update DbDep/DbtDep since we don't know if it's changed
+        force = self.is_db_import
         update_import(
             self,
             rev=rev,
@@ -453,6 +471,7 @@ def update(
             remote=remote,
             no_download=no_download,
             jobs=jobs,
+            force=force,
         )
 
     def reload(self) -> "Stage":
diff --git a/dvc/stage/decorators.py b/dvc/stage/decorators.py
index 1d6672a208..2b15e0fd61 100644
--- a/dvc/stage/decorators.py
+++ b/dvc/stage/decorators.py
@@ -7,6 +7,7 @@ def rwlocked(call, read=None, write=None):
     import sys
 
+    from dvc.dependency.db import AbstractDependency
     from dvc.dependency.repo import RepoDependency
     from dvc.rwlock import rwlock
 
@@ -27,7 +28,7 @@ def _chain(names):
             for item in getattr(stage, attr)
             # There is no need to lock RepoDependency deps, as there is no
             # corresponding OutputREPO, so we can't even write it.
-            if not isinstance(item, RepoDependency)
+            if not isinstance(item, (RepoDependency, AbstractDependency))
         ]
 
     cmd = " ".join(sys.argv)
diff --git a/dvc/stage/imports.py b/dvc/stage/imports.py
index 391be83626..3b3c5e07a0 100644
--- a/dvc/stage/imports.py
+++ b/dvc/stage/imports.py
@@ -18,7 +18,13 @@ def _update_import_on_remote(stage, remote, jobs):
 
 
 def update_import(
-    stage, rev=None, to_remote=False, remote=None, no_download=None, jobs=None
+    stage,
+    rev=None,
+    to_remote=False,
+    remote=None,
+    no_download=None,
+    jobs=None,
+    force=False,
 ):
     stage.deps[0].update(rev=rev)
@@ -30,7 +36,7 @@ def update_import(
         if to_remote:
             _update_import_on_remote(stage, remote, jobs)
         else:
-            stage.reproduce(no_download=no_download, jobs=jobs)
+            stage.reproduce(no_download=no_download, jobs=jobs, force=force)
     finally:
         if no_download and changed:
             # Avoid retaining stale information
@@ -55,7 +61,7 @@ def sync_import(
     else:
         stage.save_deps()
         if no_download:
-            if stage.is_repo_import:
+            if stage.is_repo_import or stage.is_db_import:
                 stage.deps[0].update()
             else:
                 stage.deps[0].download(
diff --git a/dvc/stage/utils.py b/dvc/stage/utils.py
index 9ae7d62742..2a294a76a2 100644
--- a/dvc/stage/utils.py
+++ b/dvc/stage/utils.py
@@ -77,12 +77,14 @@ def fill_stage_outputs(stage, **kwargs):
     )
 
 
-def fill_stage_dependencies(stage, deps=None, erepo=None, params=None, fs_config=None):
+def fill_stage_dependencies(
+    stage, deps=None, erepo=None, params=None, fs_config=None, db=None
+):
     from dvc.dependency import loads_from, loads_params
 
     assert not stage.deps
     stage.deps = []
-    stage.deps += loads_from(stage, deps or [], erepo=erepo, fs_config=fs_config)
+    stage.deps += loads_from(stage, deps or [], erepo=erepo, fs_config=fs_config, db=db)
     stage.deps += loads_params(stage, params or [])
diff --git a/dvc/utils/db.py b/dvc/utils/db.py
new file mode 100644
index 0000000000..00e75c381a
--- /dev/null
+++ b/dvc/utils/db.py
@@ -0,0 +1,252 @@
+import os
+from contextlib import contextmanager
+from importlib.util import find_spec
+from typing import (
+    TYPE_CHECKING,
+    List,
+    Optional,
+    Sequence,
+    Union,
+)
+
+from funcy import cut_prefix, identity
+
+from dvc.exceptions import DvcException
+from dvc.log import logger
+
+from . import packaging
+
+if TYPE_CHECKING:
+    from agate import Table
+    from dbt.config.profile import Profile
+    from dbt.contracts.results import RunResult
+    from rich.status import Status
+
+logger = logger.getChild(__name__)
+
+
+class DbtInternalError(DvcException):
+    pass
+
+
+def _check_dbt(action: Optional[str]):
+    if not (find_spec("dbt") and find_spec("dbt.cli")):
+        action = f" {action}" if action else ""
+        raise DvcException(f"Could not run{action}. dbt-core is not installed")
+    packaging.check_required_version(pkg="dbt-core")
+
+
+@contextmanager
+def check_dbt(action: Optional[str]):
+    _check_dbt(action)
+    yield
+
+
+def _ref(
+    name: str,
+    package: Optional[str] = None,
+    version: Optional[int] = None,
+) -> str:
+    parts: List[str] = []
+    if package:
+        parts.append(repr(package))
+
+    parts.append(repr(name))
+    if version:
+        parts.append(f"{version=}")
+
+    inner = ",".join(parts)
+    return "{{ ref(" + inner + ") }}"
+
+
+def _kw_to_cmd_args(**kwargs: Union[None, bool, str, int]) -> List[str]:
+    args: List[str] = []
+    for key, value in kwargs.items():
+        key = key.replace("_", "-")
+        if value is None:
+            continue  # skip creating a flag in this case
+        if value is True:
+            args.append(f"--{key}")
+        elif value is False:
+            args.append(f"--no-{key}")
+        else:
+            args.extend([f"--{key}", str(value)])
+    return args
+
+
+def _dbt_invoke(*posargs: str, quiet: bool = True, **kw: Union[None, bool, str, int]):
+    from dbt.cli.main import dbtRunner
+
+    args = _kw_to_cmd_args(quiet=quiet or None)  # global options
+    args.extend([*posargs, *_kw_to_cmd_args(**kw)])
+
+    runner = dbtRunner()
+    result = runner.invoke(args)
+    if result.success:
+        return result.result
+    raise DbtInternalError(f"failed to run dbt {posargs[0]}") from result.exception
+
+
+def _dbt_show(
+    inline: Optional[str] = None,
+    limit: int = -1,
+    profile: Optional[str] = None,
+    target: Optional[str] = None,
+) -> "Table":
+    from dbt.contracts.results import RunExecutionResult
+
+    result = _dbt_invoke(
+        "show",
+        inline=inline,
+        limit=limit,
+        profile=profile,
+        target=target,
+    )
+    assert isinstance(result, RunExecutionResult)
+
+    run_results: Sequence["RunResult"] = result.results
+    run_result, *_ = run_results
+    assert run_result.agate_table is not None
+    return run_result.agate_table
+
+
+@check_dbt("model")
+def get_model(
+    name: str,
+    package: Optional[str] = None,
+    version: Optional[int] = None,
+    profile: Optional[str] = None,
+    target: Optional[str] = None,
+) -> "Table":
+    model = _ref(name, package, version=version)
+    q = f"select * from {model}"  # noqa: S608
+    return _dbt_show(
+        inline=q,
+        profile=profile,
+        target=target,
+    )
+
+
+def _profiles_dir(project_dir: Optional[str] = None) -> str:
+    from dbt.cli.resolvers import default_profiles_dir
+
+    if profiles_dir := os.getenv("DBT_PROFILES_DIR"):
+        return profiles_dir
+    if project_dir and os.path.isfile(os.path.join(project_dir, "profiles.yml")):
+        return project_dir
+    return os.fspath(default_profiles_dir())
+
+
+@contextmanager
+def _global_dbt_flags(
+    profiles_dir: str,
+    project_dir: str,
+    target: Optional[str] = None,
+):
+    from argparse import Namespace
+
+    from dbt import flags
+
+    prev = flags.get_flags()
+    try:
+        args = Namespace(
+            use_colors=True,
+            project_dir=project_dir,
+            profiles_dir=profiles_dir,
+            target=target,
+        )
+        flags.set_from_args(args, None)
+        yield flags.get_flags()
+    finally:
+        flags.set_flags(prev)
+
+
+def _get_profile_or(
+    project_dir: Optional[str], profile: Optional[str], target: Optional[str]
+) -> "Profile":
+    from dbt.config.profile import Profile
+    from dbt.config.renderer import ProfileRenderer
+    from dbt.config.runtime import load_profile
+
+    if project_dir and os.path.isfile(os.path.join(project_dir, "dbt_project.yml")):
+        return load_profile(
+            project_dir, {}, profile_name_override=profile, target_override=target
+        )
+
+    if not profile:
+        raise DvcException("No profile specified to query from.")
+
+    renderer = ProfileRenderer({})
+    return Profile.render(renderer, profile, target_override=target)
+
+
+@contextmanager
+def _handle_profile_parsing_error():
+    from dbt.exceptions import DbtRuntimeError
+
+    try:
+        yield
+    except DbtRuntimeError as e:
+        cause = e.__cause__ is not None and e.__cause__.__context__
+        if isinstance(cause, ModuleNotFoundError) and (
+            adapter := cut_prefix(cause.name, "dbt.adapters.")
+        ):
+            # DbtRuntimeError is very noisy, so send it to debug
+            logger.debug("", exc_info=True)
+            raise DvcException(f"dbt-{adapter} dependency is missing") from cause
+        raise DvcException("failed to read connection profiles") from e
+
+
+@contextmanager
+def _suppress_status(status: Optional["Status"]):
+    if status:
+        status.stop()
+
+    yield
+    if status:
+        status.start()
+
+
+@check_dbt("query")
+def execute_sql(
+    sql: str,
+    profiles_dir: str,
+    project_dir: Optional[str],
+    profile: Optional[str],
+    target: Optional[str] = None,
+    status: Optional["Status"] = None,
+) -> "Table":
+    from dbt.adapters import factory as adapters_factory
+    from dbt.adapters.sql import SQLAdapter
+
+    flags = _global_dbt_flags(profiles_dir, os.getcwd(), target=target)
+    update_status = status.update if status else identity
+
+    with flags, adapters_factory.adapter_management():
+        # likely invalid connection profile or no adapter
+        with _handle_profile_parsing_error(), _suppress_status(status):
+            profile_obj = _get_profile_or(project_dir, profile, target)
+
+        adapters_factory.register_adapter(profile_obj)  # type: ignore[arg-type]
+        adapter = adapters_factory.get_adapter(profile_obj)  # type: ignore[arg-type]
+
+        assert isinstance(adapter, SQLAdapter)
+        with adapter.connection_named("debug"):
+            update_status("Testing Connection")
+            try:
+                adapter.debug_query()
+            except Exception as exc:
+                if status:
+                    status.stop()
+                logger.exception(
" + "Please check your database credentials and try again.", + exc_info=False, + ) + raise DvcException("The database returned the following error") from exc + + with adapter.connection_named("execute"): + update_status("Executing query") + exec_resp = adapter.execute(sql, fetch=True) + _, table = exec_resp + return table diff --git a/dvc/utils/packaging.py b/dvc/utils/packaging.py new file mode 100644 index 0000000000..776901771b --- /dev/null +++ b/dvc/utils/packaging.py @@ -0,0 +1,32 @@ +import logging + +from funcy import once_per_args + +from dvc.log import logger + +logger = logger.getChild(__name__) + + +@once_per_args +def check_required_version(pkg: str, dist: str = "dvc", log_level=logging.WARNING): + from importlib import metadata + + from packaging.requirements import InvalidRequirement, Requirement + + try: + reqs = { + r.name: r.specifier for r in map(Requirement, metadata.requires(dist) or []) + } + version = metadata.version(pkg) + except (metadata.PackageNotFoundError, InvalidRequirement): + return + + specifier = reqs.get(pkg) + if specifier and version and version not in specifier: + logger.log( + log_level, + "%s%s is required, but you have %r installed which is incompatible.", + pkg, + specifier, + version, + ) diff --git a/pyproject.toml b/pyproject.toml index 41db83b46a..ef8f9f34db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,7 +71,8 @@ dependencies = [ [project.optional-dependencies] all = ["dvc[azure,gdrive,gs,hdfs,oss,s3,ssh,webdav,webhdfs]"] azure = ["dvc-azure>=2.23.0"] -dev = ["dvc[azure,gdrive,gs,hdfs,lint,oss,s3,ssh,tests,webdav,webhdfs]"] +db = ["dbt-core>=1.5"] +dev = ["dvc[azure,gdrive,gs,hdfs,lint,oss,s3,ssh,tests,webdav,webhdfs,db]"] gdrive = ["dvc-gdrive==2.20"] gs = ["dvc-gs==2.22.1"] hdfs = ["dvc-hdfs==2.19"] @@ -207,6 +208,7 @@ warn_unused_configs = true [[tool.mypy.overrides]] ignore_missing_imports = true module = [ + "agate.*", "celery.*", "configobj.*", "dpath.*", diff --git a/tests/unit/stage/test_stage.py b/tests/unit/stage/test_stage.py index 60cad2ae90..2139fa3c07 100644 --- a/tests/unit/stage/test_stage.py +++ b/tests/unit/stage/test_stage.py @@ -69,7 +69,7 @@ def test_stage_update(dvc, mocker): is_repo_import.return_value = True with dvc.lock: stage.update() - reproduce.assert_called_once_with(no_download=None, jobs=None) + reproduce.assert_called_once_with(no_download=None, jobs=None, force=False) is_repo_import.return_value = False with pytest.raises(StageUpdateError):