From 232b67362dedd2f66e2c89bf9b715f70166e7393 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Rapha=C3=ABl=20Bournhonesque?= Date: Tue, 3 Dec 2024 15:07:50 +0100 Subject: [PATCH] feat: add nutrition annotation --- robotoff/insights/annotate.py | 95 ++++++++++++++++++++++++++++++++++- robotoff/off.py | 30 ++++++++++- robotoff/types.py | 64 ++++++++++++++++++++++- 3 files changed, 186 insertions(+), 3 deletions(-) diff --git a/robotoff/insights/annotate.py b/robotoff/insights/annotate.py index 0e7f58a814..1637177eb4 100644 --- a/robotoff/insights/annotate.py +++ b/robotoff/insights/annotate.py @@ -7,6 +7,7 @@ from enum import Enum from typing import Optional, Type +from pydantic import ValidationError from requests.exceptions import ConnectionError as RequestConnectionError from requests.exceptions import HTTPError, SSLError, Timeout @@ -20,6 +21,7 @@ add_packaging, add_store, save_ingredients, + save_nutrients, select_rotate_image, unselect_image, update_emb_codes, @@ -27,7 +29,7 @@ update_quantity, ) from robotoff.products import get_image_id, get_product -from robotoff.types import InsightAnnotation, InsightType, JSONType +from robotoff.types import InsightAnnotation, InsightType, JSONType, NutrientData from robotoff.utils import get_logger logger = get_logger(__name__) @@ -53,6 +55,7 @@ class AnnotationStatus(Enum): error_failed_update = 10 error_invalid_data = 11 user_input_updated = 12 + cannot_vote = 13 SAVED_ANNOTATION_RESULT = AnnotationResult( @@ -106,6 +109,11 @@ class AnnotationStatus(Enum): status=AnnotationStatus.error_invalid_data.name, description="The data schema is invalid.", ) +CANNOT_VOTE_RESULT = AnnotationResult( + status_code=AnnotationStatus.cannot_vote.value, + status=AnnotationStatus.cannot_vote.name, + description="The voting mechanism is not compatible with this insight type, please authenticate.", +) class InsightAnnotator(metaclass=abc.ABCMeta): @@ -693,6 +701,90 @@ def process_annotation( return UPDATED_ANNOTATION_RESULT +NUTRIENT_DEFAULT_UNIT = { + "energy-kcal": "kcal", + "energy-kj": "kJ", + "proteins": "g", + "carbohydrates": "g", + "sugars": "g", + "added-sugars": "g", + "fat": "g", + "saturated-fat": "g", + "trans-fat": "g", + "fiber": "g", + "salt": "g", + "iron": "mg", + "sodium": "mg", + "calcium": "mg", + "potassium": "mg", + "cholesterol": "mg", + "vitamin-d": "µg", +} + + +class NutrientExtractionAnnotator(InsightAnnotator): + @classmethod + def process_annotation( + cls, + insight: ProductInsight, + data: dict | None = None, + auth: OFFAuthentication | None = None, + is_vote: bool = False, + ) -> AnnotationResult: + if is_vote: + return CANNOT_VOTE_RESULT + + # The annotator can change the nutrient values to fix the model errors + if data is not None: + try: + validated_nutrients = cls.validate_data(data) + except ValidationError as e: + return AnnotationResult( + status_code=AnnotationStatus.error_invalid_data.value, + status=AnnotationStatus.error_invalid_data.name, + description=str(e), + ) + # We override the predicted nutrient values by the ones submitted by the + # user + insight.data["annotation"] = validated_nutrients.model_dump() + insight.data["was_updated"] = True + insight.save() + else: + validated_nutrients = NutrientData.model_validate(insight.data) + for nutrient_name, nutrient_value in validated_nutrients.nutrients.items(): + if ( + nutrient_value.unit is None + and nutrient_name in NUTRIENT_DEFAULT_UNIT + ): + nutrient_value.unit = NUTRIENT_DEFAULT_UNIT[nutrient_name] + + insight.data["annotation"] = validated_nutrients.model_dump() + insight.data["was_updated"] = False + insight.save() + + save_nutrients( + product_id=insight.get_product_id(), + nutrient_data=validated_nutrients, + insight_id=insight.id, + auth=auth, + is_vote=is_vote, + ) + return UPDATED_ANNOTATION_RESULT + + @classmethod + def validate_data(cls, data: JSONType) -> NutrientData: + """Validate the `data` field submitted by the client. + + :params data: the data submitted by the client + :return: the validated data + + :raises ValidationError: if the data is invalid + """ + if "nutrients" not in data: + raise ValidationError("missing 'nutrients' field") + return NutrientData.model_validate(data) + + ANNOTATOR_MAPPING: dict[str, Type] = { InsightType.packager_code.name: PackagerCodeAnnotator, InsightType.label.name: LabelAnnotator, @@ -705,6 +797,7 @@ def process_annotation( InsightType.nutrition_image.name: NutritionImageAnnotator, InsightType.is_upc_image.name: UPCImageAnnotator, InsightType.ingredient_spellcheck.name: IngredientSpellcheckAnnotator, + InsightType.nutrient_extraction: NutrientExtractionAnnotator, } diff --git a/robotoff/off.py b/robotoff/off.py index c183fc44c3..39836316f5 100644 --- a/robotoff/off.py +++ b/robotoff/off.py @@ -11,7 +11,7 @@ from requests.exceptions import JSONDecodeError from robotoff import settings -from robotoff.types import JSONType, ProductIdentifier, ServerType +from robotoff.types import JSONType, NutrientData, ProductIdentifier, ServerType from robotoff.utils import get_logger, http_session logger = get_logger(__name__) @@ -435,6 +435,34 @@ def save_ingredients( update_product(params, server_type=product_id.server_type, auth=auth, **kwargs) +def save_nutrients( + product_id: ProductIdentifier, + nutrient_data: NutrientData, + insight_id: str | None = None, + auth: OFFAuthentication | None = None, + is_vote: bool = False, + **kwargs, +): + """Save nutrient information for a product.""" + comment = generate_edit_comment( + "Update nutrient values", is_vote, auth is None, insight_id + ) + params = { + "code": product_id.barcode, + "comment": comment, + "nutrition_data_per": nutrient_data.nutrition_data_per, + } + if nutrient_data.serving_size: + params["serving_size"] = nutrient_data.serving_size + + for nutrient_name, nutrient_value in nutrient_data.nutrients.items(): + if nutrient_value.unit: + params[f"nutriment_{nutrient_name}"] = nutrient_value.value + params[f"nutriment_{nutrient_name}_unit"] = nutrient_value.unit + + update_product(params, server_type=product_id.server_type, auth=auth, **kwargs) + + def update_product( params: dict, server_type: ServerType, diff --git a/robotoff/types.py b/robotoff/types.py index 37b9b2b9f0..41db96e633 100644 --- a/robotoff/types.py +++ b/robotoff/types.py @@ -2,7 +2,10 @@ import datetime import enum import uuid -from typing import Any, Literal, Optional +from collections import Counter +from typing import Any, Literal, Optional, Self + +from pydantic import BaseModel, model_validator #: A precise expectation of what mappings looks like in json. #: (dict where keys are always of type `str`). @@ -360,3 +363,62 @@ class BatchJobType(enum.Enum): """Each job type correspond to a task that will be executed in the batch job.""" ingredients_spellcheck = "ingredients-spellcheck" + + +class NutrientSingleValue(BaseModel): + value: str + unit: str | None = None + + +class NutrientData(BaseModel): + nutrients: dict[str, NutrientSingleValue] + serving_size: str | None = None + nutrition_data_per: Literal["100g", "serving"] | None = None + + @model_validator(mode="before") + @classmethod + def move_fields(cls, data: Any) -> Any: + if isinstance(data, dict) and "nutrients" in data: + if "serving_size" in data["nutrients"]: + # In the input data, `serving_size` is a key of the `nutrients` + # while on Product Opener, it's a different field. We move it + # to the root of the dict to be compliant with Product Opener + # API. + serving_size = data["nutrients"].pop("serving_size") + if isinstance(serving_size, dict): + data["serving_size"] = serving_size["value"] + else: + data["serving_size"] = serving_size + return data + + @model_validator(mode="after") + def validate_nutrients(self) -> Self: + if len(self.nutrients) == 0: + raise ValueError("at least one nutrient is required") + + # We expect all nutrient keys to be in the format `{nutrient_name}_{unit}` + # where `unit` is either `100g` or `serving`. + if not all("_" in k for k in self.nutrients): + raise ValueError("each nutrient key must end with '_100g' or '_serving'") + + # We select the as `nutrition_data_per` the most common value between `100g` + # and `serving`. + # When the data is submitted by the client, we expect all nutrient keys to + # have the same unit (either `100g` or `serving`). + # When the annotation is performed directly from the insight (without user + # update or validation), we select the most common key. + nutrition_data_per_count = Counter( + key.rsplit("_", maxsplit=1)[1] for key in self.nutrients.keys() + ) + if set(nutrition_data_per_count.keys()).difference({"100g", "serving"}): + # Some keys are not ending with '_100g' or '_serving' + raise ValueError("each nutrient key must end with '_100g' or '_serving'") + + # Select the most common nutrition data per + self.nutrition_data_per = nutrition_data_per_count.most_common(1)[0][0] # type: ignore + self.nutrients = { + k.rsplit("_", maxsplit=1)[0]: v + for k, v in self.nutrients.items() + if k.endswith(self.nutrition_data_per) # type: ignore + } + return self