-
Notifications
You must be signed in to change notification settings - Fork 45.7k
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing
2 changed files
with
270 additions
and
0 deletions.
There are no files selected for viewing
115 changes: 115 additions & 0 deletions
115
.../waste_identification_ml/docker_solution/prediction_pipeline/prediction_postprocessing.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
# Copyright 2024 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
"""This script is tailored for processing outputs from two Mask R-CNN models. | ||
It is designed to handle object detection and segmentation tasks, combines | ||
outputs from two Mask R-CNN models. This involves aggregating detected objects | ||
and their respective masks and bounding boxes. Identifies and removes duplicate | ||
detections in the merged result, ensuring each detected object is unique. | ||
Extracts and compiles features of the detected objects, which may include | ||
aspects like size, area, color, or other model-specific attributes. | ||
""" | ||
|
||
import sys | ||
import numpy as np | ||
|
||
sys.path.append( | ||
'models/official/projects/waste_identification_ml/model_inference/' | ||
) | ||
from official.projects.waste_identification_ml.model_inference import postprocessing # pylint: disable=g-import-not-at-top,g-bad-import-order | ||
|
||
HEIGHT, WIDTH = 512, 1024 | ||
|
||
|
||
def merge_predictions( | ||
results: list[dict[str, np.ndarray]], | ||
score: float, | ||
category_indices: list[list[str]], | ||
category_index: dict[int, dict[str, str]], | ||
max_detection: int, | ||
) -> dict[str, np.ndarray]: | ||
"""Merges and refines prediction results. | ||
This function takes the prediction results from two models, reframes masks to | ||
the original image size, and aligns similar masks from both model outputs. It | ||
then merges these masks into a single result based on the given threshold | ||
criteria. The criteria include a minimum score threshold, an area threshold, | ||
and category alignment using provided indices and dictionary. | ||
Args: | ||
results: Outputs from 2 Mask RCNN models. | ||
score: The minimum score threshold for filtering out the detections. | ||
category_indices: Class labels of 2 models. | ||
category_index: A dictionary mapping class IDs to class labels. | ||
max_detection: Maximum number of detections from both models. | ||
Returns: | ||
Merged and filtered detection results. | ||
""" | ||
# This threshold will be used to eliminate all the detected objects whose | ||
# area is greater than the 'area_threshold'. | ||
area_threshold = 0.3 * HEIGHT * WIDTH | ||
|
||
# Reframe the masks from the output of the model to its original size. | ||
results_reframed = [ | ||
postprocessing.reframing_masks(detection, HEIGHT, WIDTH) | ||
for detection in results | ||
] | ||
|
||
# Align similar masks from both the model outputs and merge all the | ||
# properties into a single mask. Function will only compare first | ||
# 'max_detection' objects. All the objects which have less than | ||
# 'score' probability will be eliminated. All objects whose area is | ||
# more than 'area_threshold' will be eliminated. 'category_dict' and | ||
# 'category_index' are used to find the label from the combinations of | ||
# labels from both individual models. The output should include masks | ||
# appearing in either of the models if they qualify the criteria. | ||
final_result = postprocessing.find_similar_masks( | ||
results_reframed[0], | ||
results_reframed[1], | ||
max_detection, | ||
score, | ||
category_indices, | ||
category_index, | ||
area_threshold, | ||
) | ||
return final_result | ||
|
||
|
||
def _transform_bounding_boxes( | ||
results: dict[str, np.ndarray] | ||
) -> list[list[int]]: | ||
"""Transforms normalized bounding box coordinates to their original format. | ||
This function takes a dictionary containing normalized bounding box | ||
coordinates and transforms these coordinates to their original scale based on | ||
the provided image height and width. | ||
Args: | ||
results: A dictionary containing detection results. Expected to have a key | ||
'detection_boxes' with a numpy array of normalized coordinates. | ||
Returns: | ||
A list of transformed bounding boxes, each represented as [ymin, xmin, ymax, | ||
xmax] in the original image scale. | ||
""" | ||
transformed_boxes = [] | ||
for bb in results['detection_boxes'][0]: | ||
ymin = int(bb[0] * HEIGHT) | ||
xmin = int(bb[1] * WIDTH) | ||
ymax = int(bb[2] * HEIGHT) | ||
xmax = int(bb[3] * WIDTH) | ||
transformed_boxes.append([ymin, xmin, ymax, xmax]) | ||
return transformed_boxes |
155 changes: 155 additions & 0 deletions
155
...e_identification_ml/docker_solution/prediction_pipeline/prediction_postprocessing_test.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,155 @@ | ||
# Copyright 2024 The TensorFlow Authors. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
|
||
import unittest | ||
from unittest import mock | ||
import numpy as np | ||
from official.projects.waste_identification_ml.docker_solution.prediction_pipeline import prediction_postprocessing | ||
|
||
|
||
class PostprocessingTest(unittest.TestCase): | ||
|
||
def setUp(self): | ||
super().setUp() | ||
self.results1 = { | ||
'detection_boxes': [np.array([[0, 0, 100, 100], [100, 100, 200, 200]])], | ||
'detection_masks': [ | ||
np.zeros((1, 512, 1024), dtype=np.uint8), | ||
np.ones((1, 512, 1024), dtype=np.uint8), | ||
], | ||
'detection_scores': [[0.9, 0.8]], | ||
'detection_classes': [1, 2], | ||
'detection_classes_names': ['class_1', 'class_2'], | ||
} | ||
|
||
self.results2 = { | ||
'detection_boxes': [ | ||
np.array([[50, 50, 150, 150], [150, 150, 250, 250]]) | ||
], | ||
'detection_masks': [ | ||
np.full((1, 512, 1024), 0.5, dtype=np.uint8), | ||
np.full((1, 512, 1024), 0.5, dtype=np.uint8), | ||
], | ||
'detection_scores': [[0.9, 0.8]], | ||
'detection_classes': [2, 1], | ||
'detection_classes_names': ['class_2', 'class_1'], | ||
} | ||
|
||
self.category_indices = [[1, 2], [2, 1]] | ||
|
||
self.category_index = { | ||
1: {'id': 1, 'name': 'class_1'}, | ||
2: {'id': 2, 'name': 'class_2'}, | ||
} | ||
self.height = 512 | ||
self.width = 1024 | ||
|
||
def test_merge_predictions(self): | ||
results = prediction_postprocessing.merge_predictions( | ||
[self.results1, self.results2], | ||
0.8, | ||
self.category_indices, | ||
self.category_index, | ||
4, | ||
) | ||
|
||
self.assertEqual(results['num_detections'], 4) | ||
self.assertEqual(results['detection_scores'].shape, (4,)) | ||
self.assertEqual(results['detection_boxes'].shape, (4, 4)) | ||
self.assertEqual(results['detection_classes'].shape, (4,)) | ||
self.assertEqual( | ||
results['detection_classes_names'], | ||
['class_1', 'class_2', 'class_1', 'class_2'], | ||
) | ||
self.assertEqual(results['detection_masks_reframed'].shape, (4, 512, 1024)) | ||
|
||
@mock.patch('postprocessing.find_similar_masks') | ||
def test_merge_predictions_calls_find_similar_masks( | ||
self, mock_find_similar_masks | ||
): | ||
prediction_postprocessing.merge_predictions( | ||
[self.results1, self.results2], | ||
0.8, | ||
self.category_indices, | ||
self.category_index, | ||
4, | ||
) | ||
|
||
mock_find_similar_masks.assert_called_once_with( | ||
self.results1, | ||
self.results2, | ||
4, | ||
0.8, | ||
self.category_indices, | ||
self.category_index, | ||
0.3 * 512 * 1024, | ||
) | ||
|
||
def test_merge_predictions_with_empty_results(self): | ||
results = prediction_postprocessing.merge_predictions( | ||
[{}, {}], | ||
0.8, | ||
self.category_indices, | ||
self.category_index, | ||
4, | ||
) | ||
|
||
self.assertEqual(results['num_detections'], 0) | ||
self.assertEqual(results['detection_scores'].shape, (0,)) | ||
self.assertEqual(results['detection_boxes'].shape, (0, 4)) | ||
self.assertEqual(results['detection_classes'].shape, (0,)) | ||
self.assertEqual(results['detection_classes_names'], []) | ||
self.assertEqual(results['detection_masks_reframed'].shape, (0, 512, 1024)) | ||
|
||
def test_merge_predictions_with_invalid_category_indices(self): | ||
category_indices = [[1, 3], [2, 4]] | ||
|
||
with self.assertRaises(ValueError): | ||
prediction_postprocessing.merge_predictions( | ||
[self.results1, self.results2], | ||
0.8, | ||
category_indices, | ||
self.category_index, | ||
4, | ||
) | ||
|
||
def test_transform_bounding_boxes(self): | ||
results = { | ||
'detection_boxes': np.array([[ | ||
[0.1, 0.2, 0.4, 0.5], # Normalized coordinates | ||
[0.3, 0.3, 0.6, 0.7], | ||
]]) | ||
} | ||
|
||
# Expected output for the adjusted height and width | ||
expected_transformed_boxes = [ | ||
[ | ||
int(0.1 * self.height), | ||
int(0.2 * self.width), | ||
int(0.4 * self.height), | ||
int(0.5 * self.width), | ||
], | ||
[ | ||
int(0.3 * self.height), | ||
int(0.3 * self.width), | ||
int(0.6 * self.height), | ||
int(0.7 * self.width), | ||
], | ||
] | ||
|
||
transformed_boxes = prediction_postprocessing._transform_bounding_boxes( | ||
results | ||
) | ||
|
||
self.assertEqual(transformed_boxes, expected_transformed_boxes) |