-
Notifications
You must be signed in to change notification settings - Fork 1.2k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
293 additions
and
45 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,164 @@ | ||
import logging | ||
import os | ||
from contextlib import contextmanager, redirect_stdout | ||
from typing import TYPE_CHECKING, Any, Dict, Union | ||
|
||
from dvc.scm import SCM | ||
|
||
from .base import Dependency | ||
|
||
if TYPE_CHECKING: | ||
from dvc.stage import Stage | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
def log_streams(): | ||
return redirect_stdout(StreamLogger(logging.DEBUG)) | ||
|
||
|
||
class StreamLogger: | ||
def __init__(self, level): | ||
self.level = level | ||
|
||
def write(self, message): | ||
logger.log(self.level, message) | ||
|
||
|
||
class DbDependency(Dependency): | ||
PARAM_DB = "db" | ||
PARAM_PROFILE = "profile" | ||
PARAM_MODEL = "model" | ||
PARAM_QUERY = "query" | ||
PARAM_EXPORT_FORMAT = "export_format" | ||
|
||
DB_SCHEMA = { | ||
PARAM_DB: { | ||
PARAM_MODEL: str, | ||
PARAM_QUERY: str, | ||
PARAM_PROFILE: str, | ||
PARAM_EXPORT_FORMAT: str, | ||
} | ||
} | ||
|
||
def __init__( | ||
self, def_repo: Dict[str, Any], stage: "Stage", *args, **kwargs | ||
): # pylint: disable=super-init-not-called | ||
self.repo = stage.repo | ||
self.def_repo = def_repo | ||
self.db_info = kwargs.pop("db", {}) | ||
self.fs = None | ||
self.fs_path = None | ||
self.def_path = None # type: ignore[assignment] | ||
# super().__init__(stage, *args, **kwargs) | ||
|
||
def __repr__(self): | ||
return "{}:{}".format( | ||
self.__class__.__name__, | ||
"".join(f"{k}=={v}" for k, v in {**self.db_info, **self.def_repo}.items()), | ||
) | ||
|
||
def __str__(self): | ||
from .repo import RepoDependency | ||
|
||
repo = self.def_repo.get(RepoDependency.PARAM_REPO) | ||
rev = self.def_repo.get(RepoDependency.PARAM_REV) | ||
|
||
db = self.db_info.get(self.PARAM_MODEL) | ||
if not db: | ||
from dvc.utils.humanize import truncate_text | ||
|
||
db = truncate_text(self.db_info.get(self.PARAM_QUERY, "[query]"), 50) | ||
|
||
repo_info = "" | ||
if repo: | ||
repo_info += repo | ||
if rev: | ||
repo_info += f"@{rev}" | ||
return db + (f"({repo_info})" if repo_info else "") | ||
|
||
def workspace_status(self): | ||
return False | ||
|
||
def status(self): | ||
return self.workspace_status() | ||
|
||
def save(self): | ||
pass | ||
|
||
def dumpd(self, **kwargs) -> Dict[str, Union[str, Dict[str, str]]]: | ||
from .repo import RepoDependency | ||
|
||
return { | ||
self.PARAM_DB: self.db_info, | ||
# pylint: disable-next=protected-access | ||
RepoDependency.PARAM_REPO: RepoDependency._dump_def_repo(self.def_repo), | ||
} | ||
|
||
def update(self, rev=None): | ||
from dvc.repo.open_repo import _cached_clone | ||
|
||
from .repo import RepoDependency | ||
|
||
if rev: | ||
self.def_repo[RepoDependency.PARAM_REV] = rev | ||
else: | ||
rev = self.def_repo.get(RepoDependency.PARAM_REV) | ||
|
||
url = self.def_repo.get(RepoDependency.PARAM_URL) | ||
repo_root = self.repo.root_dir if self.repo else os.getcwd() | ||
project_dir = _cached_clone(url, rev) if url else repo_root | ||
self.def_repo[RepoDependency.PARAM_REV_LOCK] = SCM(project_dir).get_rev() | ||
|
||
def download(self, to, jobs=None, export_format=None): # noqa: ARG002 | ||
from dvc.repo.open_repo import _cached_clone | ||
from dvc.ui import ui | ||
|
||
from .repo import RepoDependency | ||
|
||
url = self.def_repo.get(RepoDependency.PARAM_URL) | ||
rev = self.def_repo.get(RepoDependency.PARAM_REV) | ||
rev_lock = self.def_repo.get(RepoDependency.PARAM_REV_LOCK) | ||
|
||
repo_root = self.repo.root_dir if self.repo else os.getcwd() | ||
project_dir = _cached_clone(url, rev or rev_lock) if url else repo_root | ||
|
||
self.def_repo[RepoDependency.PARAM_REV_LOCK] = SCM(project_dir).get_rev() | ||
|
||
self._download_dbt(project_dir, to, export_format=export_format) | ||
|
||
ui.write(f"Saved file to {to}", styled=True) | ||
|
||
def _download_dbt(self, project_dir, to, export_format=None): | ||
from funcy import log_durations | ||
|
||
from dvc.ui import ui | ||
|
||
with log_streams(): | ||
from fal.dbt import FalDbt | ||
|
||
faldbt = FalDbt(profiles_dir="~/.dbt", project_dir=project_dir) | ||
|
||
@contextmanager | ||
def log_status(msg, log=logger.debug): | ||
with log_durations(log, msg), ui.status(msg): | ||
yield | ||
|
||
if model := self.db_info.get(self.PARAM_MODEL): | ||
with log_status(f"Downloading {model}"), log_streams(): | ||
model = faldbt.ref(model) | ||
elif query := self.db_info.get(self.PARAM_QUERY): | ||
with log_status("Executing sql query"), log_streams(): | ||
model = faldbt.execute_sql(query) | ||
else: | ||
raise AssertionError("neither a query nor a model received") | ||
|
||
export_format = export_format or self.db_info.get( | ||
self.PARAM_EXPORT_FORMAT, "csv" | ||
) | ||
exporter = { | ||
"csv": model.to_csv, | ||
"json": model.to_json, | ||
} | ||
with log_status(f"Saving to {to}"), log_streams(): | ||
return exporter[export_format](to.fs_path) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,60 @@ | ||
from typing import TYPE_CHECKING, Optional | ||
|
||
from dvc.exceptions import OutputDuplicationError | ||
from dvc.repo.scm_context import scm_context | ||
from dvc.utils import resolve_output, resolve_paths | ||
|
||
if TYPE_CHECKING: | ||
from . import Repo | ||
|
||
from . import locked | ||
|
||
|
||
@locked | ||
@scm_context | ||
def imp_db( | ||
self: "Repo", | ||
url: str, | ||
target: str, | ||
type: str = "model", # noqa: A002, pylint: disable=redefined-builtin | ||
out: Optional[str] = None, | ||
rev: Optional[str] = None, | ||
frozen: bool = True, | ||
force: bool = False, | ||
export_format: str = "csv", | ||
): | ||
erepo = {"url": url} | ||
if rev: | ||
erepo["rev"] = rev | ||
|
||
assert type in ("model", "query") | ||
assert export_format in ("csv", "json") | ||
if not out: | ||
out = "results.csv" if type == "query" else f"{target}.{export_format}" | ||
|
||
db = {type: target, "export_format": export_format} | ||
out = resolve_output(url, out, force=force) | ||
path, wdir, out = resolve_paths(self, out, always_local=True) | ||
stage = self.stage.create( | ||
single_stage=True, | ||
validate=False, | ||
fname=path, | ||
wdir=wdir, | ||
deps=[url], | ||
outs=[out], | ||
erepo=erepo, | ||
fs_config=None, | ||
db=db, | ||
) | ||
|
||
try: | ||
self.check_graph(stages={stage}) | ||
except OutputDuplicationError as exc: | ||
raise OutputDuplicationError( # noqa: B904 | ||
exc.output, set(exc.stages) - {stage} | ||
) | ||
|
||
stage.run() | ||
stage.frozen = frozen | ||
stage.dump() | ||
return stage |
Oops, something went wrong.