feat: correct object detection metrics (#3490)
This PR:
- fixes an issue that made it impossible to compute OD metrics
- adds per-class object detection metrics
pawel-kmiecik authored Aug 7, 2024
1 parent 24a1f29 commit eba12da
Showing 5 changed files with 311 additions and 119 deletions.
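For context, the per-class precision, recall, and f1-score this PR reports follow the standard detection-evaluation definitions (average precision additionally requires ranked predictions and is omitted here). A minimal, illustrative sketch with hypothetical per-class counts, not taken from the repository's implementation:

```python
# Hypothetical per-class detection counts; class names and numbers are made up.
counts = {
    "Table": {"tp": 8, "fp": 2, "fn": 4},
    "Image": {"tp": 5, "fp": 5, "fn": 0},
}

for class_name, c in counts.items():
    precision = c["tp"] / (c["tp"] + c["fp"]) if (c["tp"] + c["fp"]) else 0.0
    recall = c["tp"] / (c["tp"] + c["fn"]) if (c["tp"] + c["fn"]) else 0.0
    f1 = 2 * precision * recall / (precision + recall) if (precision + recall) else 0.0
    print(f"{class_name}: precision={precision:.3f} recall={recall:.3f} f1={f1:.3f}")
```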
5 changes: 4 additions & 1 deletion CHANGELOG.md
@@ -1,14 +1,17 @@
## 0.15.2-dev2
## 0.15.2-dev3

### Enhancements

### Features

* **Added per-class Object Detection metrics in the evaluation**. The metrics include average precision, precision, recall, and f1-score for each class in the dataset.

### Fixes

* **Renames Astra to Astra DB.** Conforms with DataStax internal naming conventions.
* **Accommodate single-column CSV files.** Resolves a limitation of `partition_csv()` where delimiter detection would fail on a single-column CSV file (which naturally has no delimiters).
* **Accommodate `image/jpg` in PPTX as alias for `image/jpeg`.** Resolves a problem partitioning PPTX files having an invalid `image/jpg` (should be `image/jpeg`) MIME-type in the `[Content_Types].xml` member of the PPTX Zip archive.
* **Fixes an issue in Object Detection metrics.** The issue was in preprocessing/validating the ground truth and predicted data for object detection metrics.

## 0.15.1

2 changes: 1 addition & 1 deletion unstructured/__version__.py
@@ -1 +1 @@
__version__ = "0.15.2-dev2" # pragma: no cover
__version__ = "0.15.2-dev3" # pragma: no cover
16 changes: 13 additions & 3 deletions unstructured/ingest/evaluate.py
@@ -6,7 +6,8 @@

from unstructured.metrics.evaluate import (
ElementTypeMetricsCalculator,
ObjectDetectionMetricsCalculator,
ObjectDetectionAggregatedMetricsCalculator,
ObjectDetectionPerClassMetricsCalculator,
TableStructureMetricsCalculator,
TextExtractionMetricsCalculator,
filter_metrics,
@@ -291,14 +292,23 @@ def measure_object_detection_metrics_command(
output_list: Optional[List[str]] = None,
source_list: Optional[List[str]] = None,
):
return (
ObjectDetectionMetricsCalculator(
aggregated_df = (
ObjectDetectionAggregatedMetricsCalculator(
documents_dir=output_dir,
ground_truths_dir=source_dir,
)
.on_files(document_paths=output_list, ground_truth_paths=source_list)
.calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
)
per_class_df = (
ObjectDetectionPerClassMetricsCalculator(
documents_dir=output_dir,
ground_truths_dir=source_dir,
)
.on_files(document_paths=output_list, ground_truth_paths=source_list)
.calculate(export_dir=export_dir, visualize_progress=visualize, display_agg_df=True)
)
return aggregated_df, per_class_df


@main.command()
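A minimal sketch of driving the two calculators directly rather than through the CLI command above, using only the constructor and method signatures visible in this diff; the directory paths are hypothetical:

```python
from pathlib import Path

from unstructured.metrics.evaluate import (
    ObjectDetectionAggregatedMetricsCalculator,
    ObjectDetectionPerClassMetricsCalculator,
)

# Hypothetical locations of the OD dumps, ground truth files, and export dir.
output_dir = Path("structured-output")
source_dir = Path("gold-standard")
export_dir = Path("metrics")

aggregated_df = (
    ObjectDetectionAggregatedMetricsCalculator(
        documents_dir=output_dir, ground_truths_dir=source_dir
    )
    # as in the CLI command above, the explicit path lists may be None
    .on_files(document_paths=None, ground_truth_paths=None)
    .calculate(export_dir=export_dir, visualize_progress=False, display_agg_df=True)
)
per_class_df = (
    ObjectDetectionPerClassMetricsCalculator(
        documents_dir=output_dir, ground_truths_dir=source_dir
    )
    .on_files(document_paths=None, ground_truth_paths=None)
    .calculate(export_dir=export_dir, visualize_progress=False, display_agg_df=True)
)
```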
179 changes: 157 additions & 22 deletions unstructured/metrics/evaluate.py
@@ -3,6 +3,7 @@
from __future__ import annotations

import concurrent.futures
import json
import logging
import os
import sys
@@ -18,7 +19,9 @@
calculate_element_type_percent_match,
get_element_type_frequency,
)
from unstructured.metrics.object_detection import ObjectDetectionEvalProcessor
from unstructured.metrics.object_detection import (
ObjectDetectionEvalProcessor,
)
from unstructured.metrics.table.table_eval import TableEvalProcessor
from unstructured.metrics.text_extraction import calculate_accuracy, calculate_percent_missing_text
from unstructured.metrics.utils import (
@@ -68,10 +71,14 @@ def __post_init__(self):

# -- auto-discover all files in the directories --
self._document_paths = [
path.relative_to(self.documents_dir) for path in self.documents_dir.rglob("*")
path.relative_to(self.documents_dir)
for path in self.documents_dir.glob("*")
if path.is_file()
]
self._ground_truth_paths = [
path.relative_to(self.ground_truths_dir) for path in self.ground_truths_dir.rglob("*")
path.relative_to(self.ground_truths_dir)
for path in self.ground_truths_dir.glob("*")
if path.is_file()
]

@property
@@ -147,7 +154,13 @@ def calculate(
def _default_executor(cls):
max_processors = int(os.environ.get("MAX_PROCESSES", os.cpu_count()))
logger.info(f"Configuring a pool of {max_processors} processors for parallel processing.")
return concurrent.futures.ProcessPoolExecutor(max_workers=max_processors)
return cls._get_executor_class()(max_workers=max_processors)

@classmethod
def _get_executor_class(
cls,
) -> type[concurrent.futures.ThreadPoolExecutor] | type[concurrent.futures.ProcessPoolExecutor]:
return concurrent.futures.ProcessPoolExecutor
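The `_get_executor_class` hook above is what lets a subclass swap the pool type; a minimal sketch (not part of this commit) of a hypothetical subclass that opts into threads, assuming it lives next to `BaseMetricsCalculator` in this module:

```python
import concurrent.futures


class ThreadedMetricsCalculator(BaseMetricsCalculator):  # hypothetical subclass
    @classmethod
    def _get_executor_class(cls) -> type[concurrent.futures.ThreadPoolExecutor]:
        # A thread pool can be preferable for I/O-bound per-document work.
        return concurrent.futures.ThreadPoolExecutor
```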

def _process_all_documents(
self, executor: concurrent.futures.Executor, visualize_progress: bool
@@ -336,6 +349,17 @@ def _validate_inputs(self):
"Specified file type under `documents_dir` or `output_list` should be one of "
f"`json` or `txt`. The given file type is {self.document_type}, exiting."
)
for path in self._document_paths:
try:
path.suffixes[-1]
except IndexError:
logger.error(f"File {path} does not have a suffix, skipping")
continue
if path.suffixes[-1] != f".{self.document_type}":
logger.warning(
"The directory contains file type inconsistent with the given input. "
"Please note that some files will be skipped."
)
if not all(path.suffixes[-1] == f".{self.document_type}" for path in self._document_paths):
logger.warning(
"The directory contains file type inconsistent with the given input. "
@@ -598,7 +622,7 @@ def filter_metrics(


@dataclass
class ObjectDetectionMetricsCalculator(BaseMetricsCalculator):
class ObjectDetectionMetricsCalculatorBase(BaseMetricsCalculator, ABC):
"""
Calculates object detection metrics for each document:
- f1 score
@@ -613,6 +637,7 @@ def __post_init__(self):
self._document_paths = [
path.relative_to(self.documents_dir)
for path in self.documents_dir.rglob("analysis/*/layout_dump/object_detection.json")
if path.is_file()
]

@property
Expand Down Expand Up @@ -643,8 +668,9 @@ def _find_file_in_ground_truth(self, file_stem: str) -> Optional[Path]:
return path
return None

def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate metrics for a single document.
def _get_paths(self, doc: Path) -> tuple[str, Path, Path]:
"""Resolves the doctype, prediction file path, and ground truth file path.
As the OD dump directory structure differs from other simple outputs, it needs
specific processing to match the output OD dump file with the corresponding
OD GT file.
@@ -667,39 +693,29 @@
doc (Path): path to the OD dump file
Returns:
list: a list of metrics (representing a single row) for a single document
tuple: doctype, prediction file path, ground truth path
"""
od_dump_path = Path(doc)
file_stem = od_dump_path.parts[-3] # we take the `document_name` - so the filename stem

src_gt_filename = self._find_file_in_ground_truth(file_stem)

if src_gt_filename not in self._ground_truth_paths:
return None
raise ValueError(f"Ground truth file {src_gt_filename} not found in list of GT files")

doctype = Path(src_gt_filename.stem).suffix[1:]

prediction_file = self.documents_dir / doc
if not prediction_file.exists():
logger.warning(f"Prediction file {prediction_file} does not exist, skipping")
return None
raise ValueError(f"Prediction file {prediction_file} does not exist")

ground_truth_file = self.ground_truths_dir / src_gt_filename
if not ground_truth_file.exists():
logger.warning(f"Ground truth file {ground_truth_file} does not exist, skipping")
return None
raise ValueError(f"Ground truth file {ground_truth_file} does not exist")

processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
metrics = processor.get_metrics()

return [
src_gt_filename.stem,
doctype,
None, # connector
] + [getattr(metrics, metric) for metric in self.supported_metric_names]
return doctype, prediction_file, ground_truth_file

def _generate_dataframes(self, rows) -> tuple[pd.DataFrame, pd.DataFrame]:
headers = ["filename", "doctype", "connector"] + self.supported_metric_names
@@ -722,3 +738,122 @@ def _generate_dataframes(self, rows) -> tuple[pd.DataFrame, pd.DataFrame]:
agg_df.columns = AGG_HEADERS

return df, agg_df


class ObjectDetectionPerClassMetricsCalculator(ObjectDetectionMetricsCalculatorBase):
"""Calculates object detection metrics for each document, reported separately per class."""

def __post_init__(self):
super().__post_init__()
self.per_class_metric_names: list[str] | None = None
self._set_supported_metrics()

@property
def supported_metric_names(self):
if self.per_class_metric_names:
return self.per_class_metric_names
else:
raise ValueError("per_class_metrics not initialized - cannot get class names")

@property
def default_tsv_name(self):
return "all-docs-object-detection-metrics-per-class.tsv"

@property
def default_agg_tsv_name(self):
return "aggregate-object-detection-metrics-per-class.tsv"

def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate per-class metrics for a single document.
Args:
doc (Path): path to the OD dump file
Returns:
list: a list of per-class metrics (representing a single row) for a single document
"""
try:
doctype, prediction_file, ground_truth_file = self._get_paths(doc)
except ValueError as e:
logger.error(f"Failed to process document {doc}: {e}")
return None

processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
_, per_class_metrics = processor.get_metrics()

per_class_metrics_row = [
ground_truth_file.stem,
doctype,
None, # connector
]

for combined_metric_name in self.supported_metric_names:
metric = "_".join(combined_metric_name.split("_")[:-1])
class_name = combined_metric_name.split("_")[-1]
class_metrics = getattr(per_class_metrics, metric)
per_class_metrics_row.append(class_metrics[class_name])
return per_class_metrics_row

def _set_supported_metrics(self):
"""Sets the supported metrics based on the classes found in the ground truth files.
The difference between the per-class and aggregated calculators is that the list of
classes (and therefore the metrics) is based on the contents of the GT / prediction files.
"""
metrics = ["f1_score", "precision", "recall", "m_ap"]
classes = set()
for gt_file in self._ground_truth_paths:
gt_file_path = self.ground_truths_dir / gt_file
with open(gt_file_path) as f:
gt = json.load(f)
gt_classes = gt["object_detection_classes"]
classes.update(gt_classes)
per_class_metric_names = []
for metric in metrics:
for class_name in classes:
per_class_metric_names.append(f"{metric}_{class_name}")
self.per_class_metric_names = sorted(per_class_metric_names)
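The per-class column names built above are plain `{metric}_{class}` strings, and `_process_document` above recovers the two parts by splitting on the last underscore; a small standalone illustration with hypothetical class names:

```python
metrics = ["f1_score", "precision", "recall", "m_ap"]
classes = {"Image", "Table"}  # hypothetical object_detection_classes from a GT file

column_names = sorted(f"{metric}_{cls}" for metric in metrics for cls in classes)

for combined in column_names:
    metric = "_".join(combined.split("_")[:-1])  # e.g. "f1_score"
    class_name = combined.split("_")[-1]         # e.g. "Table"
    print(f"{combined} -> metric={metric!r}, class={class_name!r}")
```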


class ObjectDetectionAggregatedMetricsCalculator(ObjectDetectionMetricsCalculatorBase):
"""Calculates object detection metrics for each document and aggregates by all classes"""

@property
def supported_metric_names(self):
return ["f1_score", "precision", "recall", "m_ap"]

@property
def default_tsv_name(self):
return "all-docs-object-detection-metrics.tsv"

@property
def default_agg_tsv_name(self):
return "aggregate-object-detection-metrics.tsv"

def _process_document(self, doc: Path) -> Optional[list]:
"""Calculate both class-aggregated and per-class metrics for a single document.
Args:
doc (Path): path to the OD dump file
Returns:
list: a list of aggregated metrics for a single document
"""
try:
doctype, prediction_file, ground_truth_file = self._get_paths(doc)
except ValueError as e:
logger.error(f"Failed to process document {doc}: {e}")
return None

processor = ObjectDetectionEvalProcessor.from_json_files(
prediction_file_path=prediction_file,
ground_truth_file_path=ground_truth_file,
)
metrics, _ = processor.get_metrics()

return [
ground_truth_file.stem,
doctype,
None, # connector
] + [getattr(metrics, metric) for metric in self.supported_metric_names]
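Taken together, each per-document row emitted by `_process_document` starts with `filename`, `doctype`, `connector` and is followed by one column per supported metric name (see `_generate_dataframes` above); a hypothetical per-class row might therefore look like:

```python
# Headers truncated to two per-class columns; metric values are made up.
headers = ["filename", "doctype", "connector", "f1_score_Image", "f1_score_Table"]
row = ["report", "pdf", None, 0.91, 0.84]
print(dict(zip(headers, row)))
```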