Skip to content

Commit

Permalink
Fix dlc parsing on Benchmark __init__
Browse files Browse the repository at this point in the history
  • Loading branch information
ptoupas committed Jan 14, 2025
1 parent 4d3bc5b commit e8bc974
Show file tree
Hide file tree
Showing 3 changed files with 11 additions and 6 deletions.
2 changes: 2 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
default_language_version:
python: python3
repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.2
Expand Down
4 changes: 2 additions & 2 deletions modelconverter/packages/base_benchmark.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@


class Benchmark(ABC):
VALID_EXTENSIONS = (".tar.xz", ".blob")
VALID_EXTENSIONS = (".tar.xz", ".blob", ".dlc")
HUB_MODEL_PATTERN = re.compile(r"^(?:([^/]+)/)?([^:]+):(.+)$")

def __init__(
Expand All @@ -40,7 +40,7 @@ def __init__(
if not hub_match:
raise ValueError(
"Invalid 'model-path' format. Expected either:\n"
"- Model file path: path/to/model.blob or path/to/model.tar.xz\n"
"- Model file path: path/to/model.blob, path/to/model.dlc or path/to/model.tar.xz\n"
"- HubAI model slug: [team_name/]model_name:variant"
)
team_name, model_name, model_variant = hub_match.groups()
Expand Down
11 changes: 7 additions & 4 deletions modelconverter/utils/hubai_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,13 @@ def is_hubai_available(model_name: str, model_variant: str) -> bool:

model_variants = []
for is_public in [True, False]:
model_variants += Request.get(
"modelVersions/",
params={"model_id": model_id, "is_public": is_public},
)
try:
model_variants += Request.get(
"modelVersions/",
params={"model_id": model_id, "is_public": is_public},
)
except Exception:
pass

for version in model_variants:
if f"{model_name}:{version['variant_slug']}" == model_slug:
Expand Down

0 comments on commit e8bc974

Please sign in to comment.