From 80a710565ca23c58cc73c93a1ed69c1991512eef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Thu, 5 Dec 2024 09:59:22 +0100 Subject: [PATCH 1/2] fix(nutrisight): improve post-processing - fix postprocessing bug for addition of unit - correct OCR error for serving_size when 'g' was mistaken as '9' --- robotoff/prediction/nutrition_extraction.py | 16 +++++- .../prediction/test_nutrition_extraction.py | 57 +++++++++++++++++++ 2 files changed, 70 insertions(+), 3 deletions(-) diff --git a/robotoff/prediction/nutrition_extraction.py b/robotoff/prediction/nutrition_extraction.py index 145f105e70..40007d70aa 100644 --- a/robotoff/prediction/nutrition_extraction.py +++ b/robotoff/prediction/nutrition_extraction.py @@ -416,6 +416,9 @@ def postprocess_aggregated_entities( return postprocessed_entities +SERVING_SIZE_MISSING_G = re.compile(r"([0-9]+[,.]?[0-9]*)\s*9") + + def postprocess_aggregated_entities_single(entity: JSONType) -> JSONType: """Postprocess a single aggregated entity and return an entity with the extracted information. This is the first step in the postprocessing of aggregated entities. 
@@ -466,6 +469,11 @@ def postprocess_aggregated_entities_single(entity: JSONType) -> JSONType: if entity_label == "serving_size": value = words_str + # Sometimes the unit 'g' in the `serving_size` is detected as a '9' # In such cases, we replace the '9' with 'g' + match = SERVING_SIZE_MISSING_G.match(value) if match: value = f"{match.group(1)} g" elif words_str in ("trace", "traces"): value = "traces" else: @@ -549,13 +557,15 @@ def match_nutrient_value( for target in ( "proteins", "sugars", - "added-sugars", "carbohydrates", "fat", - "saturated-fat", "fiber", "salt", - "trans-fat", + # we use "_" here as separator as '-' is only used in + # Product Opener, the label names are all separated by '_' + "saturated_fat", + "added_sugars", + "trans_fat", ) ) and value.endswith("9") diff --git a/tests/unit/prediction/test_nutrition_extraction.py b/tests/unit/prediction/test_nutrition_extraction.py index 3c7021ff5e..effe567f74 100644 --- a/tests/unit/prediction/test_nutrition_extraction.py +++ b/tests/unit/prediction/test_nutrition_extraction.py @@ -4,6 +4,7 @@ aggregate_entities, match_nutrient_value, postprocess_aggregated_entities, + postprocess_aggregated_entities_single, ) @@ -392,8 +393,64 @@ def test_aggregate_entities_multiple_entities(self): ("25.9", "iron_100g", ("25.9", None, True)), ("O g", "salt_100g", ("0", "g", True)), ("O", "salt_100g", ("0", None, True)), + ("0,19", "saturated_fat_100g", ("0.1", "g", True)), ], ) def test_match_nutrient_value(words_str: str, entity_label: str, expected_output): assert match_nutrient_value(words_str, entity_label) == expected_output + + +@pytest.mark.parametrize( + "aggregated_entity,expected_output", + [ + ( + { + "end": 90, + "score": 0.9985358715057373, + "start": 89, + "words": ["0,19\n"], + "entity": "SATURATED_FAT_100G", + "char_end": 459, + "char_start": 454, + }, + { + "char_end": 459, + "char_start": 454, + "end": 90, + "entity": "saturated-fat_100g", + "score": 0.9985358715057373, + "start": 89, + "text": 
"0,19", + "unit": "g", + "valid": True, + "value": "0.1", + }, + ), + ( + { + "end": 92, + "score": 0.9985358715057373, + "start": 90, + "words": ["42.5 9"], + "entity": "SERVING_SIZE", + "char_end": 460, + "char_start": 454, + }, + { + "char_end": 460, + "char_start": 454, + "end": 92, + "entity": "serving_size", + "score": 0.9985358715057373, + "start": 90, + "text": "42.5 9", + "unit": None, + "valid": True, + "value": "42.5 g", + }, + ), + ], +) +def test_postprocess_aggregated_entities_single(aggregated_entity, expected_output): + assert postprocess_aggregated_entities_single(aggregated_entity) == expected_output From 9354aac2fee14f0e6a9ec50b3584dd0ae4788bd3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Thu, 5 Dec 2024 10:15:14 +0100 Subject: [PATCH 2/2] fix: optimize rerun_import_all_images job --- robotoff/workers/tasks/import_image.py | 53 +++++++++++++++++++++----- 1 file changed, 43 insertions(+), 10 deletions(-) diff --git a/robotoff/workers/tasks/import_image.py b/robotoff/workers/tasks/import_image.py index 58526b90dd..66fe4dedda 100644 --- a/robotoff/workers/tasks/import_image.py +++ b/robotoff/workers/tasks/import_image.py @@ -92,7 +92,10 @@ def rerun_import_all_images( where_clauses.append(ImageModel.server_type == server_type.name) query = ( ImageModel.select( - ImageModel.barcode, ImageModel.image_id, ImageModel.server_type + ImageModel.id, + ImageModel.barcode, + ImageModel.image_id, + ImageModel.server_type, ) .where(*where_clauses) .order_by(ImageModel.uploaded_at.desc()) @@ -104,18 +107,16 @@ def rerun_import_all_images( if return_count: return query.count() - for barcode, image_id, server_type_str in query: + for image_model_id, barcode, image_id, server_type_str in query: if not isinstance(barcode, str) and not barcode.isdigit(): raise ValueError("Invalid barcode: %s" % barcode) product_id = ProductIdentifier(barcode, ServerType[server_type_str]) image_url = generate_image_url(product_id, image_id) ocr_url = 
generate_json_ocr_url(product_id, image_id) - enqueue_job( - run_import_image_job, - get_high_queue(product_id), - job_kwargs={"result_ttl": 0}, + run_import_image( product_id=product_id, + image_model_id=image_model_id, image_url=image_url, ocr_url=ocr_url, flags=flags, @@ -144,6 +145,9 @@ def run_import_image_job( What tasks are performed can be controlled using the `flags` parameter. By default, all tasks are performed. A new rq job is enqueued for each task. + Before running the tasks, the image is downloaded and stored in the Robotoff + DB. + :param product_id: the product identifier :param image_url: the URL of the image to import :param ocr_url: the URL of the OCR JSON file @@ -151,9 +155,6 @@ def run_import_image_job( """ logger.info("Running `import_image` for %s, image %s", product_id, image_url) - if flags is None: - flags = [flag for flag in ImportImageFlag] - source_image = get_source_from_url(image_url) product = get_product_store(product_id.server_type)[product_id] if product is None and settings.ENABLE_MONGODB_ACCESS: @@ -185,13 +186,45 @@ def run_import_image_job( ImageModel.bulk_update([image_model], fields=["deleted"]) return + run_import_image( + product_id=product_id, + image_model_id=image_model.id, + image_url=image_url, + ocr_url=ocr_url, + flags=flags, + ) + + +def run_import_image( + product_id: ProductIdentifier, + image_model_id: int, + image_url: str, + ocr_url: str, + flags: list[ImportImageFlag] | None = None, +) -> None: + """Launch all extraction tasks on an image. + + We assume that the image exists in the Robotoff DB. + + What tasks are performed can be controlled using the `flags` parameter. By + default, all tasks are performed. A new rq job is enqueued for each task. 
+ + :param product_id: the product identifier + :param image_model_id: the DB ID of the image + :param image_url: the URL of the image to import + :param ocr_url: the URL of the OCR JSON file + :param flags: the list of flags to run, defaults to None (all) + """ + if flags is None: + flags = [flag for flag in ImportImageFlag] + if ImportImageFlag.add_image_fingerprint in flags: # Compute image fingerprint, this job is low priority enqueue_job( add_image_fingerprint_job, low_queue, job_kwargs={"result_ttl": 0}, - image_model_id=image_model.id, + image_model_id=image_model_id, ) if product_id.server_type.is_food():