Skip to content

Commit

Permalink
generate .dvc file
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Nov 1, 2023
1 parent c416234 commit 0d11e39
Show file tree
Hide file tree
Showing 12 changed files with 293 additions and 45 deletions.
35 changes: 9 additions & 26 deletions dvc/commands/imp_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,38 +4,21 @@
from dvc.cli import completion
from dvc.cli.command import CmdBase
from dvc.cli.utils import append_doc_link
from dvc.ui import ui

logger = logging.getLogger(__name__)


class CmdImportDb(CmdBase):
def run(self):
from fal.dbt import FalDbt
from funcy import print_durations

from dvc.repo.open_repo import _cached_clone

clone = _cached_clone(self.args.url, self.args.rev)
faldbt = FalDbt(profiles_dir="~/.dbt", project_dir=clone)

if not self.args.sql:
name = self.args.to_materialize
out = self.args.out or f"{name}.csv"
with print_durations(f"ref {name}"), ui.status(f"Downloading {name}"):
model = faldbt.ref(name)
else:
query = self.args.to_materialize
out = self.args.out or "result.csv"
with print_durations(f"execute_sql {query}"), ui.status(
"Executing sql query"
):
model = faldbt.execute_sql(query)

with print_durations(f"to_csv {out}"), ui.status(f"Saving to {out}"):
model.to_csv(out)

ui.write(f"Saved file to {out}", styled=True)
self.repo.imp_db(
url=self.args.url,
target=self.args.to_materialize,
type="query" if self.args.sql else "model",
out=self.args.out,
rev=self.args.rev,
force=self.args.force,
)
return 0


def add_parser(subparsers, parent_parser):
Expand Down
10 changes: 9 additions & 1 deletion dvc/dependency/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from dvc.output import ARTIFACT_SCHEMA, DIR_FILES_SCHEMA, Output

from .base import Dependency
from .db import DbDependency
from .param import ParamsDependency
from .repo import RepoDependency

Expand All @@ -16,10 +17,15 @@
**RepoDependency.REPO_SCHEMA,
Output.PARAM_FILES: [DIR_FILES_SCHEMA],
Output.PARAM_FS_CONFIG: dict,
**DbDependency.DB_SCHEMA,
}


def _get(stage, p, info, **kwargs):
if info and info.get(DbDependency.PARAM_DB):
repo = info.pop(RepoDependency.PARAM_REPO)
db = info.pop(DbDependency.PARAM_DB)
return DbDependency(repo, stage, p, info, db=db)
if info and info.get(RepoDependency.PARAM_REPO):
repo = info.pop(RepoDependency.PARAM_REPO)
return RepoDependency(repo, stage, p, info)
Expand All @@ -44,9 +50,11 @@ def loadd_from(stage, d_list):
return ret


def loads_from(stage, s_list, erepo=None, fs_config=None):
def loads_from(stage, s_list, erepo=None, fs_config=None, db=None):
assert isinstance(s_list, list)
info = {RepoDependency.PARAM_REPO: erepo} if erepo else {}
if db:
info.update({DbDependency.PARAM_DB: db})
return [
_get(
stage,
Expand Down
164 changes: 164 additions & 0 deletions dvc/dependency/db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,164 @@
import logging
import os
from contextlib import contextmanager, redirect_stdout
from typing import TYPE_CHECKING, Any, Dict, Union

from dvc.scm import SCM

from .base import Dependency

if TYPE_CHECKING:
from dvc.stage import Stage

logger = logging.getLogger(__name__)


def log_streams():
return redirect_stdout(StreamLogger(logging.DEBUG))


class StreamLogger:
def __init__(self, level):
self.level = level

def write(self, message):
logger.log(self.level, message)


class DbDependency(Dependency):
PARAM_DB = "db"
PARAM_PROFILE = "profile"
PARAM_MODEL = "model"
PARAM_QUERY = "query"
PARAM_EXPORT_FORMAT = "export_format"

DB_SCHEMA = {
PARAM_DB: {
PARAM_MODEL: str,
PARAM_QUERY: str,
PARAM_PROFILE: str,
PARAM_EXPORT_FORMAT: str,
}
}

def __init__(
self, def_repo: Dict[str, Any], stage: "Stage", *args, **kwargs
): # pylint: disable=super-init-not-called
self.repo = stage.repo
self.def_repo = def_repo
self.db_info = kwargs.pop("db", {})
self.fs = None
self.fs_path = None
self.def_path = None # type: ignore[assignment]
# super().__init__(stage, *args, **kwargs)

def __repr__(self):
return "{}:{}".format(
self.__class__.__name__,
"".join(f"{k}=={v}" for k, v in {**self.db_info, **self.def_repo}.items()),
)

def __str__(self):
from .repo import RepoDependency

repo = self.def_repo.get(RepoDependency.PARAM_REPO)
rev = self.def_repo.get(RepoDependency.PARAM_REV)

db = self.db_info.get(self.PARAM_MODEL)
if not db:
from dvc.utils.humanize import truncate_text

db = truncate_text(self.db_info.get(self.PARAM_QUERY, "[query]"), 50)

repo_info = ""
if repo:
repo_info += repo
if rev:
repo_info += f"@{rev}"
return db + (f"({repo_info})" if repo_info else "")

def workspace_status(self):
return False

def status(self):
return self.workspace_status()

def save(self):
pass

def dumpd(self, **kwargs) -> Dict[str, Union[str, Dict[str, str]]]:
from .repo import RepoDependency

return {
self.PARAM_DB: self.db_info,
# pylint: disable-next=protected-access
RepoDependency.PARAM_REPO: RepoDependency._dump_def_repo(self.def_repo),
}

def update(self, rev=None):
from dvc.repo.open_repo import _cached_clone

from .repo import RepoDependency

if rev:
self.def_repo[RepoDependency.PARAM_REV] = rev
else:
rev = self.def_repo.get(RepoDependency.PARAM_REV)

url = self.def_repo.get(RepoDependency.PARAM_URL)
repo_root = self.repo.root_dir if self.repo else os.getcwd()
project_dir = _cached_clone(url, rev) if url else repo_root
self.def_repo[RepoDependency.PARAM_REV_LOCK] = SCM(project_dir).get_rev()

def download(self, to, jobs=None, export_format=None): # noqa: ARG002
from dvc.repo.open_repo import _cached_clone
from dvc.ui import ui

from .repo import RepoDependency

url = self.def_repo.get(RepoDependency.PARAM_URL)
rev = self.def_repo.get(RepoDependency.PARAM_REV)
rev_lock = self.def_repo.get(RepoDependency.PARAM_REV_LOCK)

repo_root = self.repo.root_dir if self.repo else os.getcwd()
project_dir = _cached_clone(url, rev or rev_lock) if url else repo_root

self.def_repo[RepoDependency.PARAM_REV_LOCK] = SCM(project_dir).get_rev()

self._download_dbt(project_dir, to, export_format=export_format)

ui.write(f"Saved file to {to}", styled=True)

def _download_dbt(self, project_dir, to, export_format=None):
from funcy import log_durations

from dvc.ui import ui

with log_streams():
from fal.dbt import FalDbt

faldbt = FalDbt(profiles_dir="~/.dbt", project_dir=project_dir)

@contextmanager
def log_status(msg, log=logger.debug):
with log_durations(log, msg), ui.status(msg):
yield

if model := self.db_info.get(self.PARAM_MODEL):
with log_status(f"Downloading {model}"), log_streams():
model = faldbt.ref(model)
elif query := self.db_info.get(self.PARAM_QUERY):
with log_status("Executing sql query"), log_streams():
model = faldbt.execute_sql(query)
else:
raise AssertionError("neither a query nor a model received")

export_format = export_format or self.db_info.get(
self.PARAM_EXPORT_FORMAT, "csv"
)
exporter = {
"csv": model.to_csv,
"json": model.to_json,
}
with log_status(f"Saving to {to}"), log_streams():
return exporter[export_format](to.fs_path)
25 changes: 14 additions & 11 deletions dvc/dependency/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,28 +64,31 @@ def save(self):
if self.def_repo.get(self.PARAM_REV_LOCK) is None:
self.def_repo[self.PARAM_REV_LOCK] = rev

def dumpd(self, **kwargs) -> Dict[str, Union[str, Dict[str, str]]]:
repo = {self.PARAM_URL: self.def_repo[self.PARAM_URL]}
@classmethod
def _dump_def_repo(cls, def_repo) -> Dict[str, str]:
repo = {cls.PARAM_URL: def_repo[cls.PARAM_URL]}

rev = self.def_repo.get(self.PARAM_REV)
rev = def_repo.get(cls.PARAM_REV)
if rev:
repo[self.PARAM_REV] = self.def_repo[self.PARAM_REV]
repo[cls.PARAM_REV] = def_repo[cls.PARAM_REV]

rev_lock = self.def_repo.get(self.PARAM_REV_LOCK)
rev_lock = def_repo.get(cls.PARAM_REV_LOCK)
if rev_lock:
repo[self.PARAM_REV_LOCK] = rev_lock
repo[cls.PARAM_REV_LOCK] = rev_lock

config = self.def_repo.get(self.PARAM_CONFIG)
config = def_repo.get(cls.PARAM_CONFIG)
if config:
repo[self.PARAM_CONFIG] = config
repo[cls.PARAM_CONFIG] = config

remote = self.def_repo.get(self.PARAM_REMOTE)
remote = def_repo.get(cls.PARAM_REMOTE)
if remote:
repo[self.PARAM_REMOTE] = remote
repo[cls.PARAM_REMOTE] = remote
return repo

def dumpd(self, **kwargs) -> Dict[str, Union[str, Dict[str, str]]]:
return {
self.PARAM_PATH: self.def_path,
self.PARAM_REPO: repo,
self.PARAM_REPO: self._dump_def_repo(self.def_repo),
}

def update(self, rev: Optional[str] = None):
Expand Down
2 changes: 1 addition & 1 deletion dvc/output.py
Original file line number Diff line number Diff line change
Expand Up @@ -1528,7 +1528,7 @@ def _merge_dir_version_meta(self, other: "Output"):
ARTIFACT_SCHEMA = {
**CHECKSUMS_SCHEMA,
**META_SCHEMA,
Required(Output.PARAM_PATH): str,
Output.PARAM_PATH: str,
Output.PARAM_PERSIST: bool,
Output.PARAM_CLOUD: CLOUD_SCHEMA,
Output.PARAM_HASH: str,
Expand Down
1 change: 1 addition & 0 deletions dvc/repo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ class Repo:
from dvc.repo.get import get as _get # type: ignore[misc]
from dvc.repo.get_url import get_url as _get_url # type: ignore[misc]
from dvc.repo.imp import imp # type: ignore[misc]
from dvc.repo.imp_db import imp_db # type: ignore[misc]
from dvc.repo.imp_url import imp_url # type: ignore[misc]
from dvc.repo.install import install # type: ignore[misc]
from dvc.repo.ls import ls as _ls # type: ignore[misc]
Expand Down
4 changes: 4 additions & 0 deletions dvc/repo/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@ def build_graph(stages, outs_trie=None):
for stage in stages:
if stage.is_repo_import:
continue
if stage.is_db_import:
continue

for dep in stage.deps:
dep_key = dep.fs.path.parts(dep.fs_path)
Expand All @@ -160,6 +162,8 @@ def build_outs_graph(graph, outs_trie):
for stage in graph.nodes():
if stage.is_repo_import:
continue
if stage.is_db_import:
continue
for dep in stage.deps:
dep_key = dep.fs.path.parts(dep.fs_path)
overlapping = [n.value for n in outs_trie.prefixes(dep_key)]
Expand Down
60 changes: 60 additions & 0 deletions dvc/repo/imp_db.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,60 @@
from typing import TYPE_CHECKING, Optional

from dvc.exceptions import OutputDuplicationError
from dvc.repo.scm_context import scm_context
from dvc.utils import resolve_output, resolve_paths

if TYPE_CHECKING:
from . import Repo

from . import locked


@locked
@scm_context
def imp_db(
self: "Repo",
url: str,
target: str,
type: str = "model", # noqa: A002, pylint: disable=redefined-builtin
out: Optional[str] = None,
rev: Optional[str] = None,
frozen: bool = True,
force: bool = False,
export_format: str = "csv",
):
erepo = {"url": url}
if rev:
erepo["rev"] = rev

assert type in ("model", "query")
assert export_format in ("csv", "json")
if not out:
out = "results.csv" if type == "query" else f"{target}.{export_format}"

db = {type: target, "export_format": export_format}
out = resolve_output(url, out, force=force)
path, wdir, out = resolve_paths(self, out, always_local=True)
stage = self.stage.create(
single_stage=True,
validate=False,
fname=path,
wdir=wdir,
deps=[url],
outs=[out],
erepo=erepo,
fs_config=None,
db=db,
)

try:
self.check_graph(stages={stage})
except OutputDuplicationError as exc:
raise OutputDuplicationError( # noqa: B904
exc.output, set(exc.stages) - {stage}
)

stage.run()
stage.frozen = frozen
stage.dump()
return stage
Loading

0 comments on commit 0d11e39

Please sign in to comment.