From c83074da16d529002793fb5ea27ccb80d35572ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Timoth=C3=A9e=20Mazzucotelli?= Date: Sun, 3 Nov 2024 15:44:02 +0100 Subject: [PATCH] feat: Also support `pydantic.model_validator` Issue-4: https://github.com/mkdocstrings/griffe-pydantic/issues/4 --- docs/examples/model_ext.py | 11 ++++++++++- src/griffe_pydantic/static.py | 6 +++--- .../material/_base/pydantic_model.html.jinja | 13 ++++++++----- 3 files changed, 21 insertions(+), 9 deletions(-) diff --git a/docs/examples/model_ext.py b/docs/examples/model_ext.py index 6466777..2110554 100644 --- a/docs/examples/model_ext.py +++ b/docs/examples/model_ext.py @@ -1,4 +1,5 @@ -from pydantic import field_validator, ConfigDict, BaseModel, Field +from typing import Any +from pydantic import field_validator, model_validator, ConfigDict, BaseModel, Field class ExampleModel(BaseModel): @@ -26,3 +27,11 @@ def check_max_length_ten(cls, v) -> str: if len(v) >= 10: raise ValueError("No more than 10 characters allowed") return v + + @model_validator(mode="before") + @classmethod + def lowercase_only(cls, data: dict[str, Any]) -> dict[str, Any]: + """Ensure that the field without a default is lowercase.""" + if isinstance(data.get("field_without_default"), str): + data["field_without_default"] = data["field_without_default"].lower() + return data diff --git a/src/griffe_pydantic/static.py b/src/griffe_pydantic/static.py index 70d6c91..cbbb6cb 100644 --- a/src/griffe_pydantic/static.py +++ b/src/griffe_pydantic/static.py @@ -48,7 +48,7 @@ def inherits_pydantic(cls: Class) -> bool: return any(inherits_pydantic(parent_class) for parent_class in cls.mro()) -def pydantic_field_validator(func: Function) -> ExprCall | None: +def pydantic_validator(func: Function) -> ExprCall | None: """Return a function's `pydantic.field_validator` decorator if it exists. Parameters: @@ -58,7 +58,7 @@ def pydantic_field_validator(func: Function) -> ExprCall | None: A decorator value (Griffe expression). """ for decorator in func.decorators: - if isinstance(decorator.value, ExprCall) and decorator.callable_path == "pydantic.field_validator": + if isinstance(decorator.value, ExprCall) and decorator.callable_path in {"pydantic.field_validator", "pydantic.model_validator"}: return decorator.value return None @@ -110,7 +110,7 @@ def process_function(func: Function, cls: Class, *, processed: set[str]) -> None logger.warning(f"cannot yet process {func}") return - if decorator := pydantic_field_validator(func): + if decorator := pydantic_validator(func): fields = [ast.literal_eval(field) for field in decorator.arguments if isinstance(field, str)] common.process_function(func, cls, fields) diff --git a/src/griffe_pydantic/templates/material/_base/pydantic_model.html.jinja b/src/griffe_pydantic/templates/material/_base/pydantic_model.html.jinja index 5735805..3fd20fc 100644 --- a/src/griffe_pydantic/templates/material/_base/pydantic_model.html.jinja +++ b/src/griffe_pydantic/templates/material/_base/pydantic_model.html.jinja @@ -48,11 +48,14 @@