Skip to content

Commit

Permalink
Merge pull request #1028 from roboflow/feat/new-getWeights-for-yolo
Browse files Browse the repository at this point in the history
Handle new getWeights in RoboflowInferenceModel
  • Loading branch information
PawelPeczek-Roboflow authored Feb 17, 2025
2 parents b9f473a + 7434d7a commit 54bc830
Showing 1 changed file with 53 additions and 23 deletions.
76 changes: 53 additions & 23 deletions inference/core/models/roboflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
from inference.core.roboflow_api import (
ModelEndpointType,
get_from_url,
get_roboflow_instant_model_data,
get_roboflow_model_data,
)
from inference.core.utils.image_utils import load_image
Expand Down Expand Up @@ -264,33 +265,62 @@ def model_artifact_bucket(self):

def download_model_artifacts_from_roboflow_api(self) -> None:
logger.debug("Downloading model artifacts from Roboflow API")
api_data = get_roboflow_model_data(
api_key=self.api_key,
model_id=self.endpoint,
endpoint_type=ModelEndpointType.ORT,
device_id=self.device_id,
)
if "ort" not in api_data.keys():
raise ModelArtefactError(
"Could not find `ort` key in roboflow API model description response."
)
api_data = api_data["ort"]
if "classes" in api_data:
save_text_lines_in_cache(
content=api_data["classes"],
file="class_names.txt",
if self.version_id is not None:
api_data = get_roboflow_model_data(
api_key=self.api_key,
model_id=self.endpoint,
endpoint_type=ModelEndpointType.ORT,
device_id=self.device_id,
)
if "model" not in api_data:
raise ModelArtefactError(
"Could not find `model` key in roboflow API model description response."
if "ort" not in api_data.keys():
raise ModelArtefactError(
"Could not find `ort` key in roboflow API model description response."
)
api_data = api_data["ort"]
if "classes" in api_data:
save_text_lines_in_cache(
content=api_data["classes"],
file="class_names.txt",
model_id=self.endpoint,
)
if "model" not in api_data:
raise ModelArtefactError(
"Could not find `model` key in roboflow API model description response."
)
if "environment" not in api_data:
raise ModelArtefactError(
"Could not find `environment` key in roboflow API model description response."
)
environment = get_from_url(api_data["environment"])
model_weights_response = get_from_url(
api_data["model"], json_response=False
)
if "environment" not in api_data:
raise ModelArtefactError(
"Could not find `environment` key in roboflow API model description response."
else:
api_data = get_roboflow_instant_model_data(
api_key=self.api_key,
model_id=self.endpoint,
)
environment = get_from_url(api_data["environment"])
model_weights_response = get_from_url(api_data["model"], json_response=False)
if (
"modelFiles" not in api_data
or "ort" not in api_data["modelFiles"]
or "model" not in api_data["modelFiles"]["ort"]
):
raise ModelArtefactError(
"Could not find `modelFiles` key or `modelFiles`.`ort` or `modelFiles`.`ort`.`model` key in roboflow API model description response."
)
model_weights_response = get_from_url(
api_data["modelFiles"]["ort"]["model"], json_response=False
)
if "classes" in api_data["modelFiles"]["ort"]:
save_text_lines_in_cache(
content=api_data["modelFiles"]["ort"]["classes"],
file="class_names.txt",
model_id=self.endpoint,
)
environment = {}
if "environment" in api_data["modelFiles"]["ort"]:
environment = get_from_url(api_data["modelFiles"]["ort"]["environment"])

save_bytes_in_cache(
content=model_weights_response.content,
file=self.weights_file,
Expand Down

0 comments on commit 54bc830

Please sign in to comment.