Skip to content

Commit

Permalink
Add model folder to the unzip path (#1445)
Browse files Browse the repository at this point in the history
* Add model folder to the unzip path

* Handle cases where zipped model either has no extra directory

* Add test

* Fix-up test and implementation

* Manually lint
  • Loading branch information
roomrys authored Aug 10, 2023
1 parent d61a184 commit 5ba6bc1
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 11 deletions.
8 changes: 7 additions & 1 deletion sleap/nn/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4824,6 +4824,7 @@ def unpack_sleap_model(model_path):
# Uncompress ZIP packaged models.
tmp_dirs = []
for i, model_path in enumerate(model_paths):
mp = Path(model_path)
if model_path.endswith(".zip"):
# Create temp dir on demand.
tmp_dir = tempfile.TemporaryDirectory()
Expand All @@ -4834,7 +4835,12 @@ def unpack_sleap_model(model_path):

# Extract and replace in the list.
shutil.unpack_archive(model_path, extract_dir=tmp_dir.name)
model_paths[i] = tmp_dir.name
unzipped_mp = Path(tmp_dir.name, mp.name).with_suffix("")
if Path(unzipped_mp, "best_model.h5").exists():
unzipped_model_path = str(unzipped_mp)
else:
unzipped_model_path = str(unzipped_mp.parent)
model_paths[i] = unzipped_model_path

return model_paths, tmp_dirs

Expand Down
62 changes: 52 additions & 10 deletions tests/nn/test_inference.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
import ast
import json
import zipfile
from pathlib import Path
from typing import cast
import pytest

import numpy as np
import json
from sleap.io.dataset import Labels
from sleap.nn.tracking import FlowCandidateMaker, Tracker
import pytest
import tensorflow as tf
import sleap
from numpy.testing import assert_array_equal, assert_allclose
from pathlib import Path
import tensorflow_hub as hub
from numpy.testing import assert_array_equal, assert_allclose

import sleap
from sleap.gui.learning import runners
from sleap.io.dataset import Labels
from sleap.nn.data.confidence_maps import (
make_confmaps,
make_grid_vectors,
make_multi_confmaps,
)

from sleap.nn.inference import (
InferenceLayer,
InferenceModel,
Expand Down Expand Up @@ -49,10 +51,9 @@
main as sleap_track,
export_cli as sleap_export,
)
from sleap.nn.tracking import FlowCandidateMaker, Tracker


from sleap.gui.learning import runners

sleap.nn.system.use_cpu_only()


Expand Down Expand Up @@ -832,6 +833,47 @@ def test_topdown_multiclass_predictor_high_threshold(
assert len(labels_pr[0].instances) == 0


def zip_directory_with_itself(src_dir, output_path):
"""Zip a directory, including the directory itself.
Args:
src_dir: Path to directory to zip.
output_path: Path to output zip file.
"""

src_path = Path(src_dir)
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
for file_path in src_path.rglob("*"):
arcname = src_path.name / file_path.relative_to(src_path)
zipf.write(file_path, arcname)


def zip_directory_contents(src_dir, output_path):
"""Zip the contents of a directory, not the directory itself.
Args:
src_dir: Path to directory to zip.
output_path: Path to output zip file.
"""

src_path = Path(src_dir)
with zipfile.ZipFile(output_path, "w", zipfile.ZIP_DEFLATED) as zipf:
for file_path in src_path.rglob("*"):
arcname = file_path.relative_to(src_path)
zipf.write(file_path, arcname)


@pytest.mark.parametrize(
"zip_func", [zip_directory_with_itself, zip_directory_contents]
)
def test_load_model_zipped(tmpdir, min_centroid_model_path, zip_func):
mp = Path(min_centroid_model_path)
zip_dir = Path(tmpdir, mp.name).with_name(mp.name + ".zip")
zip_func(mp, zip_dir)

predictor = load_model(str(zip_dir))


@pytest.mark.parametrize("resize_input_shape", [True, False])
@pytest.mark.parametrize(
"model_fixture_name",
Expand Down

0 comments on commit 5ba6bc1

Please sign in to comment.