-
Notifications
You must be signed in to change notification settings - Fork 25
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Restore "default" Property for Recipes (#366)
* class for recipes * fix failing test * check for Recipes instead of list * missed list instance * allow for parsing of custom default recipes * update comment * PR comments * PR comments
- Loading branch information
Showing
6 changed files
with
165 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -17,3 +17,4 @@ | |
|
||
# flake8: noqa | ||
from .file import * | ||
from .recipes import * |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import logging | ||
from typing import Dict, List, Optional, Union | ||
|
||
from sparsezoo.objects import File | ||
|
||
|
||
__all__ = ["Recipes"] | ||
|
||
_LOGGER = logging.getLogger(__name__) | ||
|
||
|
||
class Recipes: | ||
""" | ||
Object to store a list of recipes for a downloaded model and pull the default | ||
:param recipes: list of recipes to store | ||
:param stub_params: dictionary that may contain custom default recipes names | ||
""" | ||
|
||
_RECIPE_DEFAULT_NAME = "recipe.md" | ||
|
||
def __init__( | ||
self, | ||
recipes: Optional[Union[File, List[File]]], | ||
stub_params: Dict[str, str] = {}, | ||
): | ||
if recipes is None: | ||
recipes = [] | ||
if isinstance(recipes, File): | ||
recipes = [recipes] | ||
self._recipes = recipes | ||
|
||
self._default_recipe_name = self._RECIPE_DEFAULT_NAME | ||
custom_default = stub_params.get("recipe_type") or stub_params.get("recipe") | ||
if custom_default is not None: | ||
self._default_recipe_name = "recipe_" + custom_default | ||
|
||
@property | ||
def recipes(self) -> List: | ||
""" | ||
:return: The full list of recipes | ||
""" | ||
return self._recipes | ||
|
||
@property | ||
def default(self) -> File: | ||
""" | ||
:return: The default recipe in the recipe list | ||
""" | ||
for recipe in self._recipes: | ||
if recipe.name.startswith(self._default_recipe_name): | ||
return recipe | ||
|
||
# fallback to first recipe in list | ||
_LOGGER.warning( | ||
"No default recipe {self._default_recipe_name} found, falling back to" | ||
"first listed recipe" | ||
) | ||
return self._recipes[0] | ||
|
||
def get_recipe_by_name(self, recipe_name: str) -> Union[File, None]: | ||
""" | ||
Returns the File for the recipe matching the name recipe_name if it exists | ||
:param recipe_name: recipe filename to search for | ||
:return: File with the name recipe_name, or None if it doesn't exist | ||
""" | ||
|
||
for recipe in self._recipes: | ||
if recipe.name == recipe_name: | ||
return recipe | ||
|
||
return None # no matching recipe found |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
# Copyright (c) 2021 - present / Neuralmagic, Inc. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, | ||
# software distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import tempfile | ||
|
||
from sparsezoo.model import Model | ||
|
||
|
||
def test_recipe_getters(): | ||
stub_with_multiple_recipes = "zoo:bert-base-wikipedia_bookcorpus-pruned90" | ||
temp_dir = tempfile.TemporaryDirectory(dir="/tmp") | ||
model = Model(stub_with_multiple_recipes, temp_dir.name) | ||
|
||
default_recipe = model.recipes.default | ||
assert default_recipe.name == "recipe.md" | ||
|
||
all_recipes = model.recipes.recipes | ||
assert len(all_recipes) == 4 | ||
|
||
recipe_name = "recipe_transfer_text_classification.md" | ||
found_by_name = model.recipes.get_recipe_by_name(recipe_name) | ||
assert found_by_name.name == recipe_name | ||
|
||
found_by_name = model.recipes.get_recipe_by_name("does_not_exist.md") | ||
assert found_by_name is None | ||
|
||
|
||
def test_custom_default(): | ||
custom_default_name = "transfer_text_classification" | ||
stub_with_multiple_recipes = ( | ||
"zoo:bert-base-wikipedia_bookcorpus-pruned90?recipe={}".format( | ||
custom_default_name | ||
) | ||
) | ||
temp_dir = tempfile.TemporaryDirectory(dir="/tmp") | ||
model = Model(stub_with_multiple_recipes, temp_dir.name) | ||
|
||
expected_default_name = "recipe_" + custom_default_name + ".md" | ||
|
||
default_recipe = model.recipes.default | ||
assert default_recipe.name == expected_default_name |