From 4d19d4e56667ef97ee4ac8843cd784d7d2044365 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Pawe=C5=82=20P=C4=99czek?=
Date: Thu, 6 Feb 2025 12:57:51 +0100
Subject: [PATCH] Fix problem with image slicer

---
 inference/core/workflows/core_steps/loader.py |   4 +
 .../transformations/image_slicer/v2.py        | 196 +++++++++++++++
 .../execution/test_workflow_with_sahi.py      | 156 ++++++++++--
 ...mage_slicer.py => test_image_slicer_v1.py} |   0
 .../transformations/test_image_slicer_v2.py   | 227 ++++++++++++++++++
 5 files changed, 564 insertions(+), 19 deletions(-)
 create mode 100644 inference/core/workflows/core_steps/transformations/image_slicer/v2.py
 rename tests/workflows/unit_tests/core_steps/transformations/{test_image_slicer.py => test_image_slicer_v1.py} (100%)
 create mode 100644 tests/workflows/unit_tests/core_steps/transformations/test_image_slicer_v2.py

diff --git a/inference/core/workflows/core_steps/loader.py b/inference/core/workflows/core_steps/loader.py
index b200f45daa..a6b59317b0 100644
--- a/inference/core/workflows/core_steps/loader.py
+++ b/inference/core/workflows/core_steps/loader.py
@@ -303,6 +303,9 @@
 from inference.core.workflows.core_steps.transformations.image_slicer.v1 import (
     ImageSlicerBlockV1,
 )
+from inference.core.workflows.core_steps.transformations.image_slicer.v2 import (
+    ImageSlicerBlockV2,
+)
 from inference.core.workflows.core_steps.transformations.perspective_correction.v1 import (
     PerspectiveCorrectionBlockV1,
 )
@@ -611,6 +614,7 @@ def load_blocks() -> List[Type[WorkflowBlock]]:
         TwilioSMSNotificationBlockV1,
         GazeBlockV1,
         LlamaVisionBlockV1,
+        ImageSlicerBlockV2,
     ]

diff --git a/inference/core/workflows/core_steps/transformations/image_slicer/v2.py b/inference/core/workflows/core_steps/transformations/image_slicer/v2.py
new file mode 100644
index 0000000000..ae30f5407c
--- /dev/null
+++ b/inference/core/workflows/core_steps/transformations/image_slicer/v2.py
@@ -0,0 +1,196 @@
+from dataclasses import replace
+from typing import List, Literal, Optional, Tuple, Type, Union
+from uuid import uuid4
+
+import numpy as np
+from pydantic import AliasChoices, ConfigDict, Field, PositiveInt
+from supervision import crop_image
+from typing_extensions import Annotated
+
+from inference.core.workflows.execution_engine.entities.base import (
+    OutputDefinition,
+    WorkflowImageData,
+)
+from inference.core.workflows.execution_engine.entities.types import (
+    FLOAT_ZERO_TO_ONE_KIND,
+    IMAGE_KIND,
+    INTEGER_KIND,
+    Selector,
+)
+from inference.core.workflows.prototypes.block import (
+    BlockResult,
+    WorkflowBlock,
+    WorkflowBlockManifest,
+)
+
+LONG_DESCRIPTION = """
+This block enables the [Slicing Adaptive Inference (SAHI)](https://ieeexplore.ieee.org/document/9897990) technique in
+Workflows by implementing the first step of the procedure - slicing the input image into smaller tiles.
+
+To use the block effectively, it must be paired with a detection model (object detection or
+instance segmentation) running against the output images from this block. At the end, the
+Detections Stitch block must be applied on top of the predictions to merge them as if
+the prediction had been made against the input image, not its slices.
+
+We recommend adjusting the size of slices to match the model's input size and the scale of objects in the dataset
+the model was trained on. Models generally perform best on data that is similar to what they encountered during
+training. The default size of slices is 640, but this might not be optimal if the model's input size is 320, as each
+slice would be downsized by a factor of two during inference. Similarly, if the model's input size is 1280, each slice
+will be artificially up-scaled. The best setup should be determined experimentally based on the specific data and model
+you are using.
+
+To learn more about SAHI, please visit the [Roboflow blog](https://blog.roboflow.com/how-to-use-sahi-to-detect-small-objects/),
+which describes the technique in detail, although not in the context of Roboflow Workflows.
+
+#### Changes compared to **v1**
+
+* All crops generated by the slicer are of equal size
+
+* No duplicated crops are created
+"""
+
+
+class BlockManifest(WorkflowBlockManifest):
+    model_config = ConfigDict(
+        json_schema_extra={
+            "name": "Image Slicer",
+            "version": "v2",
+            "short_description": "Tile the input image into a list of smaller images to perform small object detection.",
+            "long_description": LONG_DESCRIPTION,
+            "license": "Apache-2.0",
+            "block_type": "transformation",
+            "ui_manifest": {
+                "section": "advanced",
+                "icon": "fal fa-scissors",
+                "blockPriority": 9,
+                "opencv": True,
+            },
+        }
+    )
+    type: Literal["roboflow_core/image_slicer@v2"]
+    image: Selector(kind=[IMAGE_KIND]) = Field(
+        title="Image to slice",
+        description="The input image for this step.",
+        examples=["$inputs.image", "$steps.cropping.crops"],
+        validation_alias=AliasChoices("image", "images"),
+    )
+    slice_width: Union[PositiveInt, Selector(kind=[INTEGER_KIND])] = Field(
+        default=640,
+        description="Width of each slice, in pixels",
+        examples=[320, "$inputs.slice_width"],
+    )
+    slice_height: Union[PositiveInt, Selector(kind=[INTEGER_KIND])] = Field(
+        default=640,
+        description="Height of each slice, in pixels",
+        examples=[320, "$inputs.slice_height"],
+    )
+    overlap_ratio_width: Union[
+        Annotated[float, Field(ge=0.0, lt=1.0)],
+        Selector(kind=[FLOAT_ZERO_TO_ONE_KIND]),
+    ] = Field(
+        default=0.2,
+        description="Overlap ratio between consecutive slices in the width dimension",
+        examples=[0.2, "$inputs.overlap_ratio_width"],
+    )
+    overlap_ratio_height: Union[
+        Annotated[float, Field(ge=0.0, lt=1.0)],
+        Selector(kind=[FLOAT_ZERO_TO_ONE_KIND]),
+    ] = Field(
+        default=0.2,
+        description="Overlap ratio between consecutive slices in the height dimension",
+        examples=[0.2, "$inputs.overlap_ratio_height"],
+    )
+
+    @classmethod
+    def get_output_dimensionality_offset(cls) -> int:
+        return 1
+
+    @classmethod
+    def describe_outputs(cls) -> List[OutputDefinition]:
+        return [
+            OutputDefinition(name="slices", kind=[IMAGE_KIND]),
+        ]
+
+    @classmethod
+    def get_execution_engine_compatibility(cls) -> Optional[str]:
+        return ">=1.3.0,<2.0.0"
+
+
+class ImageSlicerBlockV2(WorkflowBlock):
+
+    @classmethod
+    def get_manifest(cls) -> Type[WorkflowBlockManifest]:
+        return BlockManifest
+
+    def run(
+        self,
+        image: WorkflowImageData,
+        slice_width: int,
+        slice_height: int,
+        overlap_ratio_width: float,
+        overlap_ratio_height: float,
+    ) -> BlockResult:
+        image_numpy = image.numpy_image
+        resolution_wh = (image_numpy.shape[1], image_numpy.shape[0])
+        offsets = generate_offsets(
+            resolution_wh=resolution_wh,
+            slice_wh=(slice_width, slice_height),
+            overlap_ratio_wh=(overlap_ratio_width, overlap_ratio_height),
+        )
+        slices = []
+        for offset in offsets:
+            x_min, y_min, _, _ = offset
+            crop_numpy = crop_image(image=image_numpy, xyxy=offset)
+            if crop_numpy.size:
+                cropped_image = WorkflowImageData.create_crop(
+                    origin_image_data=image,
+                    crop_identifier=f"image_slicer.{uuid4()}",
+                    cropped_image=crop_numpy,
+                    offset_x=x_min,
+                    offset_y=y_min,
+                )
+                slices.append({"slices": cropped_image})
+            else:
+                slices.append({"slices": None})
+        return slices
+
+
+def generate_offsets(
+    resolution_wh: Tuple[int, int],
+    slice_wh: Tuple[int, int],
+    overlap_ratio_wh: Tuple[float, float],
+) -> np.ndarray:
+    """
+    This is a modified version of the function from block v1: it pushes
+    the "border" crops back towards the center of the image, which
+    guarantees that:
+    * all crops have exactly the requested size
+    * duplicated crop coordinates are dropped
+    """
+    slice_width, slice_height = slice_wh
+    image_width, image_height = resolution_wh
+    slice_width = min(slice_width, image_width)
+    slice_height = min(slice_height, image_height)
+    overlap_width = int(overlap_ratio_wh[0] * slice_width)
+    overlap_height = int(overlap_ratio_wh[1] * slice_height)
+    width_stride = slice_width - overlap_width
+    height_stride = slice_height - overlap_height
+    ws = np.arange(0, image_width, width_stride)
+    ws_left_over = np.clip(ws + slice_width - image_width, 0, slice_width)
+    hs = np.arange(0, image_height, height_stride)
+    hs_left_over = np.clip(hs + slice_height - image_height, 0, slice_height)
+    anchors_ws = ws - ws_left_over
+    anchors_hs = hs - hs_left_over
+    xmin, ymin = np.meshgrid(anchors_ws, anchors_hs)
+    xmax = np.clip(xmin + slice_width, 0, image_width)
+    ymax = np.clip(ymin + slice_height, 0, image_height)
+    results = np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4)
+    deduplicated_results = []
+    already_seen = set()
+    for xyxy in results:
+        xyxy_tuple = tuple(xyxy)
+        if xyxy_tuple in already_seen:
+            continue
+        deduplicated_results.append(xyxy)
+        already_seen.add(xyxy_tuple)
+    return np.array(deduplicated_results)
diff --git a/tests/workflows/integration_tests/execution/test_workflow_with_sahi.py b/tests/workflows/integration_tests/execution/test_workflow_with_sahi.py
index 60affd570b..67d6b5f89c 100644
--- a/tests/workflows/integration_tests/execution/test_workflow_with_sahi.py
+++ b/tests/workflows/integration_tests/execution/test_workflow_with_sahi.py
@@ -8,6 +8,7 @@
 from inference.core.entities.requests.inference import ObjectDetectionInferenceRequest
 from inference.core.env import WORKFLOWS_MAX_CONCURRENT_STEPS
 from inference.core.managers.base import ModelManager
+from inference.core.utils.drawing import create_tiles
 from inference.core.workflows.core_steps.common.entities import StepExecutionMode
 from inference.core.workflows.execution_engine.core import ExecutionEngine
 from tests.workflows.integration_tests.execution.workflows_gallery_collector.decorators import (
@@ -62,25 +63,6 @@
 }
 
 
-@add_to_workflows_gallery(
-    category="Advanced inference techniques",
-    use_case_title="SAHI in workflows - object detection",
-    use_case_description="""
-This example illustrates usage of [SAHI](https://blog.roboflow.com/how-to-use-sahi-to-detect-small-objects/)
-technique in workflows.
- -Workflows implementation requires three blocks: - -- Image Slicer - which runs a sliding window over image and for each image prepares batch of crops - -- detection model block (in our scenario Roboflow Object Detection model) - which is responsible -for making predictions on each crop - -- Detections stitch - which combines partial predictions for each slice of the image into a single prediction - """, - workflow_definition=SAHI_WORKFLOW, - workflow_name_in_app="sahi-detection", -) def test_sahi_workflow_with_none_as_filtering_strategy( model_manager: ModelManager, license_plate_image: np.ndarray, @@ -162,6 +144,142 @@ def test_sahi_workflow_with_none_as_filtering_strategy( ), "Expected boxes for second image to be exactly as measured during test creation" +SAHI_WORKFLOW_SLICER_V2 = { + "version": "1.0.0", + "inputs": [ + {"type": "WorkflowImage", "name": "image"}, + {"type": "WorkflowParameter", "name": "overlap_filtering_strategy"}, + {"type": "WorkflowParameter", "name": "slice_width", "default_value": 128}, + {"type": "WorkflowParameter", "name": "slice_height", "default_value": 128}, + {"type": "WorkflowParameter", "name": "slice_overlap", "default_value": 0.1}, + ], + "steps": [ + { + "type": "roboflow_core/image_slicer@v2", + "name": "image_slicer", + "image": "$inputs.image", + "slice_width": "$inputs.slice_width", + "slice_height": "$inputs.slice_height", + "slice_overlap": "$inputs.slice_overlap", + }, + { + "type": "roboflow_core/roboflow_object_detection_model@v2", + "name": "detection", + "image": "$steps.image_slicer.slices", + "model_id": "yolov8n-640", + }, + { + "type": "roboflow_core/detections_stitch@v1", + "name": "stitch", + "reference_image": "$inputs.image", + "predictions": "$steps.detection.predictions", + "overlap_filtering_strategy": "$inputs.overlap_filtering_strategy", + }, + { + "type": "roboflow_core/bounding_box_visualization@v1", + "name": "bbox_visualiser", + "predictions": "$steps.stitch.predictions", + "image": "$inputs.image", + }, + ], + "outputs": [ + { + "type": "JsonField", + "name": "predictions", + "selector": "$steps.stitch.predictions", + "coordinates_system": "own", + }, + { + "type": "JsonField", + "name": "slices", + "selector": "$steps.image_slicer.slices", + }, + { + "type": "JsonField", + "name": "visualisation", + "selector": "$steps.bbox_visualiser.image", + }, + ], +} + + +@add_to_workflows_gallery( + category="Advanced inference techniques", + use_case_title="SAHI in workflows - object detection", + use_case_description=""" +This example illustrates usage of [SAHI](https://blog.roboflow.com/how-to-use-sahi-to-detect-small-objects/) +technique in workflows. + +Workflows implementation requires three blocks: + +- Image Slicer - which runs a sliding window over image and for each image prepares batch of crops + +- detection model block (in our scenario Roboflow Object Detection model) - which is responsible +for making predictions on each crop + +- Detections stitch - which combines partial predictions for each slice of the image into a single prediction + """, + workflow_definition=SAHI_WORKFLOW, + workflow_name_in_app="sahi-detection", +) +def test_sahi_workflow_with_slicer_v2( + model_manager: ModelManager, + crowd_image: np.ndarray, +) -> None: + """ + In this test we check how all blocks that form SAHI technique behave. 
Blocks involved in tests:
+    - "roboflow_core/image_slicer@v2" from inference.core.workflows.core_steps.transformations.image_slicer.v2
+    - "roboflow_core/detections_stitch@v1" from inference.core.workflows.core_steps.fusion.detections_stitch.v1
+
+    This scenario covers usage of SAHI with NMS applied to merge overlapping predictions.
+    """
+    # given
+    workflow_init_parameters = {
+        "workflows_core.model_manager": model_manager,
+        "workflows_core.step_execution_mode": StepExecutionMode.LOCAL,
+    }
+    execution_engine = ExecutionEngine.init(
+        workflow_definition=SAHI_WORKFLOW_SLICER_V2,
+        init_parameters=workflow_init_parameters,
+        max_concurrent_steps=WORKFLOWS_MAX_CONCURRENT_STEPS,
+    )
+
+    # when
+    result = execution_engine.run(
+        runtime_parameters={
+            "image": crowd_image,
+            "overlap_filtering_strategy": "nms",
+        }
+    )
+
+    # then
+    assert np.allclose(
+        result[0]["predictions"].xyxy,
+        np.array(
+            [
+                [103, 103, 113, 124],
+                [182, 272, 231, 334],
+                [114, 270, 144, 334],
+                [271, 267, 329, 334],
+                [226, 288, 246, 329],
+                [240, 251, 251, 283],
+                [249, 251, 261, 284],
+                [388, 264, 413, 334],
+                [309, 265, 318, 297],
+                [359, 260, 374, 291],
+                [323, 257, 345, 318],
+                [342, 260, 361, 321],
+                [415, 259, 457, 334],
+                [552, 260, 597, 334],
+                [522, 257, 557, 334],
+                [158, 297, 181, 348],
+            ]
+        ),
+        atol=2,
+    ), "Expected boxes to be exactly as measured during test creation"
+
+
 def test_sahi_workflow_with_nms_as_filtering_strategy(
     model_manager: ModelManager,
     license_plate_image: np.ndarray,
diff --git a/tests/workflows/unit_tests/core_steps/transformations/test_image_slicer.py b/tests/workflows/unit_tests/core_steps/transformations/test_image_slicer_v1.py
similarity index 100%
rename from tests/workflows/unit_tests/core_steps/transformations/test_image_slicer.py
rename to tests/workflows/unit_tests/core_steps/transformations/test_image_slicer_v1.py
diff --git a/tests/workflows/unit_tests/core_steps/transformations/test_image_slicer_v2.py b/tests/workflows/unit_tests/core_steps/transformations/test_image_slicer_v2.py
new file mode 100644
index 0000000000..933d063e34
--- /dev/null
+++ b/tests/workflows/unit_tests/core_steps/transformations/test_image_slicer_v2.py
@@ -0,0 +1,227 @@
+import numpy as np
+import pytest
+from pydantic import ValidationError
+
+from inference.core.workflows.core_steps.transformations.image_slicer.v2 import (
+    BlockManifest,
+    ImageSlicerBlockV2,
+)
+from inference.core.workflows.execution_engine.entities.base import (
+    ImageParentMetadata,
+    OriginCoordinatesSystem,
+    WorkflowImageData,
+)
+
+
+@pytest.mark.parametrize("image_alias", ["image", "images"])
+def test_manifest_v2_parsing_when_valid_input_given(image_alias: str) -> None:
+    # given
+    raw_manifest = {
+        "type": "roboflow_core/image_slicer@v2",
+        "name": "slicer",
+        image_alias: "$inputs.image",
+        "slice_width": 100,
+        "slice_height": 200,
+        "overlap_ratio_width": 0.2,
+        "overlap_ratio_height": 0.3,
+    }
+
+    # when
+    result = BlockManifest.model_validate(raw_manifest)
+
+    # then
+    assert result == BlockManifest(
+        name="slicer",
+        type="roboflow_core/image_slicer@v2",
+        image="$inputs.image",
+        slice_width=100,
+        slice_height=200,
+        overlap_ratio_width=0.2,
+        overlap_ratio_height=0.3,
+    )
+
+
+@pytest.mark.parametrize("field_to_delete", ["image", "type", "name"])
+def test_manifest_v2_parsing_when_required_field_missing(field_to_delete: str) -> None:
+    # given
+    raw_manifest = {
+        "type": "roboflow_core/image_slicer@v2",
+        "name": "slicer",
+        "image": "$inputs.image",
+        "slice_width": 100,
+        "slice_height": 200,
+        "overlap_ratio_width": 0.2,
+        "overlap_ratio_height": 0.3,
+    }
+    del raw_manifest[field_to_delete]
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(raw_manifest)
+
+
+def test_manifest_v2_parsing_when_slice_width_outside_of_range() -> None:
+    # given
+    raw_manifest = {
+        "type": "roboflow_core/image_slicer@v2",
+        "name": "slicer",
+        "image": "$inputs.image",
+        "slice_width": -1,
+        "slice_height": 200,
+        "overlap_ratio_width": 0.2,
+        "overlap_ratio_height": 0.3,
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(raw_manifest)
+
+
+def test_manifest_v2_parsing_when_slice_height_outside_of_range() -> None:
+    # given
+    raw_manifest = {
+        "type": "roboflow_core/image_slicer@v2",
+        "name": "slicer",
+        "image": "$inputs.image",
+        "slice_width": 200,
+        "slice_height": -1,
+        "overlap_ratio_width": 0.2,
+        "overlap_ratio_height": 0.3,
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(raw_manifest)
+
+
+@pytest.mark.parametrize(
+    "overlap_property", ["overlap_ratio_width", "overlap_ratio_height"]
+)
+@pytest.mark.parametrize("invalid_value", [-0.1, 1.0])
+def test_manifest_v2_parsing_when_overlap_outside_of_range(
+    overlap_property: str, invalid_value: float
+) -> None:
+    # given
+    raw_manifest = {
+        "type": "roboflow_core/image_slicer@v2",
+        "name": "slicer",
+        "image": "$inputs.image",
+        "slice_width": 200,
+        "slice_height": 200,
+        overlap_property: invalid_value,
+    }
+
+    # when
+    with pytest.raises(ValidationError):
+        _ = BlockManifest.model_validate(raw_manifest)
+
+
+def test_running_block() -> None:
+    # given
+    image = WorkflowImageData(
+        numpy_image=np.zeros((256, 512, 3), dtype=np.uint8),
+        parent_metadata=ImageParentMetadata(parent_id="parent"),
+    )
+    block = ImageSlicerBlockV2()
+
+    # when
+    result = block.run(
+        image=image,
+        slice_width=200,
+        slice_height=100,
+        overlap_ratio_width=0.1,
+        overlap_ratio_height=0.2,
+    )
+
+    # then
+    assert len(result) == 9, "Expected exactly 9 crops"
+    for i in range(9):
+        assert result[i]["slices"].parent_metadata.parent_id.startswith(
+            "image_slicer."
+ ), f"Expected parent to be set properly for {i}th crop" + + assert result[0]["slices"].numpy_image.shape == (100, 200, 3) + assert result[0][ + "slices" + ].workflow_root_ancestor_metadata.origin_coordinates == OriginCoordinatesSystem( + left_top_x=0, left_top_y=0, origin_width=512, origin_height=256 + ), "Expected 1st crop to have the following coordinates regarding root" + assert result[1]["slices"].numpy_image.shape == (100, 200, 3) + assert result[1][ + "slices" + ].workflow_root_ancestor_metadata.origin_coordinates == OriginCoordinatesSystem( + left_top_x=180, left_top_y=0, origin_width=512, origin_height=256 + ), "Expected 2nd crop to have the following coordinates regarding root" + assert result[2]["slices"].numpy_image.shape == (100, 200, 3) + assert result[2][ + "slices" + ].workflow_root_ancestor_metadata.origin_coordinates == OriginCoordinatesSystem( + left_top_x=312, left_top_y=0, origin_width=512, origin_height=256 + ), "Expected 3rd crop to have the following coordinates regarding root" + assert result[3]["slices"].numpy_image.shape == (100, 200, 3) + assert result[3][ + "slices" + ].workflow_root_ancestor_metadata.origin_coordinates == OriginCoordinatesSystem( + left_top_x=0, left_top_y=80, origin_width=512, origin_height=256 + ), "Expected 4th crop to have the following coordinates regarding root" + assert result[4]["slices"].numpy_image.shape == (100, 200, 3) + assert result[4][ + "slices" + ].workflow_root_ancestor_metadata.origin_coordinates == OriginCoordinatesSystem( + left_top_x=180, left_top_y=80, origin_width=512, origin_height=256 + ), "Expected 5th crop to have the following coordinates regarding root" + assert result[5]["slices"].numpy_image.shape == (100, 200, 3) + assert result[5][ + "slices" + ].workflow_root_ancestor_metadata.origin_coordinates == OriginCoordinatesSystem( + left_top_x=312, left_top_y=80, origin_width=512, origin_height=256 + ), "Expected 6th crop to have the following coordinates regarding root" + assert result[6]["slices"].numpy_image.shape == (100, 200, 3) + assert result[6][ + "slices" + ].workflow_root_ancestor_metadata.origin_coordinates == OriginCoordinatesSystem( + left_top_x=0, left_top_y=156, origin_width=512, origin_height=256 + ), "Expected 7th crop to have the following coordinates regarding root" + assert result[7]["slices"].numpy_image.shape == (100, 200, 3) + assert result[7][ + "slices" + ].workflow_root_ancestor_metadata.origin_coordinates == OriginCoordinatesSystem( + left_top_x=180, left_top_y=156, origin_width=512, origin_height=256 + ), "Expected 8th crop to have the following coordinates regarding root" + assert result[8]["slices"].numpy_image.shape == (100, 200, 3) + assert result[8][ + "slices" + ].workflow_root_ancestor_metadata.origin_coordinates == OriginCoordinatesSystem( + left_top_x=312, left_top_y=156, origin_width=512, origin_height=256 + ), "Expected 9th crop to have the following coordinates regarding root" + + +def test_running_block_when_slice_size_exceed_image_size() -> None: + # given + image = WorkflowImageData( + numpy_image=np.zeros((256, 512, 3), dtype=np.uint8), + parent_metadata=ImageParentMetadata(parent_id="parent"), + ) + block = ImageSlicerBlockV2() + + # when + result = block.run( + image=image, + slice_width=2000, + slice_height=1000, + overlap_ratio_width=0.1, + overlap_ratio_height=0.2, + ) + + # then + assert len(result) == 1, "Expected exactly 1 crop" + assert result[0]["slices"].parent_metadata.parent_id.startswith( + "image_slicer." 
), "Expected parent to be set properly for 1st crop"
+
+    assert result[0]["slices"].numpy_image.shape == (256, 512, 3)
+    assert result[0][
+        "slices"
+    ].workflow_root_ancestor_metadata.origin_coordinates == OriginCoordinatesSystem(
+        left_top_x=0, left_top_y=0, origin_width=512, origin_height=256
+    ), "Expected 1st crop to have the following coordinates regarding root"
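
For reviewers who want to sanity-check the new tiling behaviour without loading the whole Workflows stack, below is a minimal standalone sketch that mirrors `generate_offsets()` from `image_slicer/v2.py` in this patch and runs it with the same configuration as `test_running_block()` (512x256 image, 200x100 slices, 0.1/0.2 overlap ratios). It is illustrative only; the implementation in the diff above is authoritative.

```python
from typing import Tuple

import numpy as np


def generate_offsets(
    resolution_wh: Tuple[int, int],
    slice_wh: Tuple[int, int],
    overlap_ratio_wh: Tuple[float, float],
) -> np.ndarray:
    # Mirror of the v2 helper: border anchors are shifted back inside the image
    # so every slice keeps the requested size, then exact duplicates are dropped.
    slice_width, slice_height = slice_wh
    image_width, image_height = resolution_wh
    slice_width = min(slice_width, image_width)
    slice_height = min(slice_height, image_height)
    overlap_width = int(overlap_ratio_wh[0] * slice_width)
    overlap_height = int(overlap_ratio_wh[1] * slice_height)
    width_stride = slice_width - overlap_width
    height_stride = slice_height - overlap_height
    ws = np.arange(0, image_width, width_stride)
    ws_left_over = np.clip(ws + slice_width - image_width, 0, slice_width)
    hs = np.arange(0, image_height, height_stride)
    hs_left_over = np.clip(hs + slice_height - image_height, 0, slice_height)
    anchors_ws = ws - ws_left_over
    anchors_hs = hs - hs_left_over
    xmin, ymin = np.meshgrid(anchors_ws, anchors_hs)
    xmax = np.clip(xmin + slice_width, 0, image_width)
    ymax = np.clip(ymin + slice_height, 0, image_height)
    results = np.stack([xmin, ymin, xmax, ymax], axis=-1).reshape(-1, 4)
    deduplicated_results, already_seen = [], set()
    for xyxy in results:
        xyxy_tuple = tuple(xyxy)
        if xyxy_tuple in already_seen:
            continue
        deduplicated_results.append(xyxy)
        already_seen.add(xyxy_tuple)
    return np.array(deduplicated_results)


if __name__ == "__main__":
    # Same setup as test_running_block(): expect 9 offsets, each exactly 200x100,
    # with x anchors {0, 180, 312} and y anchors {0, 80, 156}.
    offsets = generate_offsets(
        resolution_wh=(512, 256),
        slice_wh=(200, 100),
        overlap_ratio_wh=(0.1, 0.2),
    )
    print(offsets)
    assert len(offsets) == 9
    assert np.all(offsets[:, 2] - offsets[:, 0] == 200)
    assert np.all(offsets[:, 3] - offsets[:, 1] == 100)
```

This is exactly the behaviour that the "Changes compared to v1" note in `LONG_DESCRIPTION` refers to: border tiles are no longer clipped to a smaller size, and identical tiles are emitted only once.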