Skip to content

Commit

Permalink
Merge pull request #1036 from roboflow/fix/apply-confidence-threshold…
Browse files Browse the repository at this point in the history
…-for-classification-models

Apply confidence when inferring on classification models
  • Loading branch information
grzegorz-roboflow authored Feb 19, 2025
2 parents 3f8b654 + ec3397f commit b444481
Show file tree
Hide file tree
Showing 12 changed files with 22 additions and 5 deletions.
7 changes: 5 additions & 2 deletions inference/core/entities/responses/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,9 +262,12 @@ class ClassificationInferenceResponse(CvInferenceResponse, WithVisualizationResp
"""

predictions: List[ClassificationPrediction]
top: str = Field(description="The top predicted class label")
top: str = Field(
description="The top predicted class label", default=""
) # Not making this field optional to avoid breaking change - in other parts of the codebase `model_dump` is called with `exclude_none=True`
confidence: float = Field(
description="The confidence of the top predicted class label"
description="The confidence of the top predicted class label",
default=0.0,
)
parent_id: Optional[str] = Field(
description="Identifier of parent image region. Useful when stack of detection-models is in use to refer the RoI being the input to inference",
Expand Down
6 changes: 4 additions & 2 deletions inference/core/models/classification_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,8 @@ def make_response(
results = []
for i, cls_name in enumerate(self.class_names):
score = float(preds[i])
if score < confidence_threshold:
continue
pred = {
"class_id": i,
"class": cls_name,
Expand All @@ -363,8 +365,8 @@ def make_response(
width=img_dims[ind][1], height=img_dims[ind][0]
),
predictions=results,
top=results[0]["class"],
confidence=results[0]["confidence"],
top=results[0]["class"] if results else "",
confidence=results[0]["confidence"] if results else 0.0,
)
responses.append(response)

Expand Down
2 changes: 1 addition & 1 deletion inference/core/version.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.38.0"
__version__ = "0.39.0rc1"


if __name__ == "__main__":
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"name": "classifier",
"image": "$inputs.image",
"model_id": "$inputs.model_id",
"confidence": 0.09,
}
],
"outputs": [
Expand Down Expand Up @@ -98,6 +99,7 @@ def test_multi_class_classification_workflow(
"name": "classifier",
"image": "$inputs.image",
"model_id": "$inputs.model_id",
"confidence": 0.5,
}
],
"outputs": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
"name": "classifier",
"image": "$inputs.image",
"model_id": "$inputs.model_id",
"confidence": 0.09,
}
],
"outputs": [
Expand Down Expand Up @@ -98,6 +99,7 @@ def test_multi_class_classification_workflow(
"name": "classifier",
"image": "$inputs.image",
"model_id": "$inputs.model_id",
"confidence": 0.5,
}
],
"outputs": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"name": "breds_classification",
"image": "$steps.cropping.crops",
"model_id": "dog-breed-xpaq6/1",
"confidence": 0.09,
},
],
"outputs": [
Expand Down Expand Up @@ -103,6 +104,7 @@ def test_legacy_detection_plus_classification_workflow_when_minimal_valid_input_
"name": "breds_classification",
"image": "$steps.cropping.crops",
"model_id": "dog-breed-xpaq6/1",
"confidence": 0.09,
},
],
"outputs": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
"name": "breds_classification",
"image": "$steps.cropping.crops",
"model_id": "dog-breed-xpaq6/1",
"confidence": 0.09,
},
],
"outputs": [
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
"name": "breds_classification",
"image": "$steps.cropping.crops",
"model_id": "dog-breed-xpaq6/1",
"confidence": 0.09,
},
{
"type": "roboflow_core/roboflow_dataset_upload@v2",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"name": "breds_classification",
"image": "$steps.cropping.crops",
"model_id": "dog-breed-xpaq6/1",
"confidence": 0.09,
},
{
"type": "DetectionsClassesReplacement",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
"name": "breds_classification",
"image": "$steps.cropping.crops",
"model_id": "dog-breed-xpaq6/1",
"confidence": 0.09,
},
{
"type": "DetectionsClassesReplacement",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,6 +212,7 @@ def test_workflow_with_extraction_of_classes_for_detections(
"name": "breds_classification",
"image": "$steps.cropping.crops",
"model_id": "dog-breed-xpaq6/1",
"confidence": 0.09,
},
{
"type": "PropertyDefinition",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"name": "breds_classification",
"image": "$steps.cropping.crops",
"model_id": "dog-breed-xpaq6/1",
"confidence": 0.09,
},
{
"type": "roboflow_core/continue_if@v1",
Expand Down

0 comments on commit b444481

Please sign in to comment.