Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
skshetry committed Nov 7, 2023
1 parent d4cef6a commit 7d97c19
Show file tree
Hide file tree
Showing 8 changed files with 433 additions and 123 deletions.
49 changes: 35 additions & 14 deletions dvc/commands/imp_db.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,19 @@

class CmdImportDb(CmdBase):
def run(self):
if not (self.args.sql or self.args.model):
raise argparse.ArgumentTypeError("Either of --sql or --model is required.")

self.repo.imp_db(
url=self.args.url,
target=self.args.to_materialize,
type="query" if self.args.sql else "model",
out=self.args.out,
rev=self.args.rev,
project_dir=self.args.project_dir,
sql=self.args.sql,
model=self.args.model,
profile=self.args.profile,
target=self.args.target,
export_format=self.args.export_format,
out=self.args.out,
force=self.args.force,
)
return 0
Expand All @@ -35,16 +42,36 @@ def add_parser(subparsers, parent_parser):
formatter_class=argparse.RawTextHelpFormatter,
)
import_parser.add_argument(
"url", help="Location of DVC or Git repository to download from"
"--url", help="Location of DVC or Git repository to download from"
)
import_parser.add_argument(
"to_materialize", help="Name of the dbt model or SQL query (if --sql)"
"--rev",
nargs="?",
help="Git revision (e.g. SHA, branch, tag)",
metavar="<commit>",
)
import_parser.add_argument(
"--project-dir", nargs="?", help="Subdirectory to the dbt project location"
)

group = import_parser.add_mutually_exclusive_group()
group.add_argument(
"--sql",
help="is a sql query",
action="store_true",
default=False,
help="SQL query",
)
group.add_argument(
"--model",
help="Model name to download",
)
import_parser.add_argument("--profile", help="Profile to use")
import_parser.add_argument("--target", help="Target to use")
import_parser.add_argument(
"--export-format",
default="csv",
const="csv",
nargs="?",
choices=["csv", "json"],
help="Export format",
)
import_parser.add_argument(
"-o",
Expand All @@ -60,11 +87,5 @@ def add_parser(subparsers, parent_parser):
default=False,
help="Override destination file or folder if exists.",
)
import_parser.add_argument(
"--rev",
nargs="?",
help="Git revision (e.g. SHA, branch, tag)",
metavar="<commit>",
)

import_parser.set_defaults(func=CmdImportDb)
17 changes: 10 additions & 7 deletions dvc/dependency/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from dvc.output import ARTIFACT_SCHEMA, DIR_FILES_SCHEMA, Output

from .base import Dependency
from .db import DbDependency
from .db import DB_SCHEMA, PARAM_DB, DbDependency, DbtDependency
from .param import ParamsDependency
from .repo import RepoDependency

Expand All @@ -15,17 +15,20 @@
SCHEMA: Mapping[str, Any] = {
**ARTIFACT_SCHEMA,
**RepoDependency.REPO_SCHEMA,
**DB_SCHEMA,
Output.PARAM_FILES: [DIR_FILES_SCHEMA],
Output.PARAM_FS_CONFIG: dict,
**DbDependency.DB_SCHEMA,
}


def _get(stage, p, info, **kwargs):
if info and info.get(DbDependency.PARAM_DB):
repo = info.pop(RepoDependency.PARAM_REPO)
db = info.pop(DbDependency.PARAM_DB)
return DbDependency(repo, stage, p, info, db=db)
db = info.get(PARAM_DB, {})
if DbDependency.PARAM_QUERY in db:
return DbDependency(stage, info)
if db:
repo = info.pop(RepoDependency.PARAM_REPO, None)
return DbtDependency(repo, stage, info)

if info and info.get(RepoDependency.PARAM_REPO):
repo = info.pop(RepoDependency.PARAM_REPO)
return RepoDependency(repo, stage, p, info)
Expand Down Expand Up @@ -54,7 +57,7 @@ def loads_from(stage, s_list, erepo=None, fs_config=None, db=None):
assert isinstance(s_list, list)
info = {RepoDependency.PARAM_REPO: erepo} if erepo else {}
if db:
info.update({DbDependency.PARAM_DB: db})
info.update({"db": db})
return [
_get(
stage,
Expand Down
Loading

0 comments on commit 7d97c19

Please sign in to comment.