From 65beaaf36101fa4922abf74b9ab0d22e42ceb61a Mon Sep 17 00:00:00 2001 From: Boudewijn van Groos Date: Thu, 18 Sep 2025 13:33:22 +0200 Subject: [PATCH 1/2] Fix pydantic generics and add ability to directive json schema --- RELEASE.md | 4 + docs/integrations/pydantic.md | 58 ++++++++ strawberry/experimental/pydantic/_compat.py | 6 + .../experimental/pydantic/exceptions.py | 9 ++ strawberry/experimental/pydantic/fields.py | 5 + .../experimental/pydantic/object_type.py | 60 +++++++- .../pydantic/schema/test_basic.py | 70 ++++++++- .../schema/test_json_schema_capture.py | 140 ++++++++++++++++++ 8 files changed, 348 insertions(+), 4 deletions(-) create mode 100644 RELEASE.md create mode 100644 tests/experimental/pydantic/schema/test_json_schema_capture.py diff --git a/RELEASE.md b/RELEASE.md new file mode 100644 index 0000000000..d49b2a2dd0 --- /dev/null +++ b/RELEASE.md @@ -0,0 +1,4 @@ +Release type: minor + +This release fixes the pydantic support for generics and allows capture of the +Pydantic JSON schema attributes through a schema directive. diff --git a/docs/integrations/pydantic.md b/docs/integrations/pydantic.md index e1c1967772..0f0248bcc5 100644 --- a/docs/integrations/pydantic.md +++ b/docs/integrations/pydantic.md @@ -519,3 +519,61 @@ user_type = UserType(id="abc", content_name="Bob", content_description=None) print(user_type.to_pydantic()) # id='abc' content={: 'Bob'} ``` + +## Schema directives to capture JSON schema data + +The pydantic conversion also supports capturing the JSON schema metadata from +Pydantic. Note the fields in the directive must match the json schema names. + +```python +from pydantic import BaseModel, Field +from typing import Annotated, Optional +from strawberry.schema_directive import Location +import strawberry + + +class User(BaseModel): + id: Annotated[int, Field(gt=0)] + name: Annotated[str, Field(json_schema_extra={"name_type": "full"})] + + +@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) +class MyJsonSchema: + exclusive_minimum: Optional[int] = None + name_type: Optional[str] = None + + +@strawberry.experimental.pydantic.type(model=User, json_schema_directive=MyJsonSchema) +class UserType: + id: strawberry.auto + name: strawberry.auto + + +@strawberry.type +class Query: + @strawberry.field + def test() -> UserType: + return UserType.from_pydantic(User(id=123, name="John Doe")) + + +schema = strawberry.Schema(query=Query) +``` + +Now if [the schema is exported](../guides/schema-export), the result will +contain: + +```graphql +directive @myJsonSchema( + exclusiveMinimum: Int = null + nameType: String = null +) on FIELD_DEFINITION + +type Query { + test: UserType! +} + +type UserType { + id: Int! + name: String! @myJsonSchema(exclusiveMinimum: null, nameType: "full") +} +``` diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index baf4ea14a1..4e660fc7c1 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -178,6 +178,9 @@ def get_model_fields( new_fields |= self.get_model_computed_fields(model) return new_fields + def get_model_json_schema(self, model: type[BaseModel]) -> dict[str, Any]: + return model.model_json_schema() + @cached_property def fields_map(self) -> dict[Any, Any]: return get_fields_map_for_v2() @@ -273,6 +276,9 @@ def get_basic_type(self, type_: Any) -> type[Any]: return type_ + def get_model_json_schema(self, model: type[BaseModel]) -> dict[str, Any]: + return model.schema() + def model_dump(self, model_instance: BaseModel) -> dict[Any, Any]: return model_instance.dict() diff --git a/strawberry/experimental/pydantic/exceptions.py b/strawberry/experimental/pydantic/exceptions.py index 9c54cffc87..13a1f339e5 100644 --- a/strawberry/experimental/pydantic/exceptions.py +++ b/strawberry/experimental/pydantic/exceptions.py @@ -53,3 +53,12 @@ def __init__( ) super().__init__(message) + + +class JsonSchemaDirectiveNotRegistered(Exception): + def __init__(self, json_schema_directive: type) -> None: + message = ( + f"{json_schema_directive} not registered as a strawberry schema directive" + ) + + super().__init__(message) diff --git a/strawberry/experimental/pydantic/fields.py b/strawberry/experimental/pydantic/fields.py index 447dcd9e6a..857c0f0d3d 100644 --- a/strawberry/experimental/pydantic/fields.py +++ b/strawberry/experimental/pydantic/fields.py @@ -46,6 +46,11 @@ def replace_types_recursively( origin = get_origin(type_) if not origin or not hasattr(type_, "__args__"): + if ( + hasattr(basic_type, "__pydantic_generic_metadata__") + and basic_type.__pydantic_generic_metadata__["args"] + ): + return replaced_type[basic_type.__pydantic_generic_metadata__["args"]] return replaced_type converted = tuple( diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index ed7a2e2d11..f4e9b55229 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -8,6 +8,7 @@ Any, Callable, Optional, + TypeVar, cast, ) @@ -20,7 +21,10 @@ convert_pydantic_model_to_strawberry_class, convert_strawberry_class_to_pydantic_model, ) -from strawberry.experimental.pydantic.exceptions import MissingFieldsListError +from strawberry.experimental.pydantic.exceptions import ( + JsonSchemaDirectiveNotRegistered, + MissingFieldsListError, +) from strawberry.experimental.pydantic.fields import replace_types_recursively from strawberry.experimental.pydantic.utils import ( DataclassCreationFields, @@ -33,6 +37,7 @@ from strawberry.types.field import StrawberryField from strawberry.types.object_type import _process_type, _wrap_dataclass from strawberry.types.type_resolver import _get_fields +from strawberry.utils.str_converters import to_snake_case if TYPE_CHECKING: import builtins @@ -62,6 +67,8 @@ def _build_dataclass_creation_fields( auto_fields_set: set[str], use_pydantic_alias: bool, compat: PydanticCompat, + json_schema: dict[str, Any], + json_schema_directive: Optional[builtins.type] = None, ) -> DataclassCreationFields: field_type = ( get_type_for_field(field, is_input, compat=compat) @@ -84,6 +91,21 @@ def _build_dataclass_creation_fields( elif field.has_alias and use_pydantic_alias: graphql_name = field.alias + directives = existing_field.directives if existing_field else () + if ( + json_schema_directive + and json_schema + and ( + json_directive := _generate_json_directive( + json_schema, json_schema_directive + ) + ) + ): + directives = ( + *directives, + json_directive, + ) + strawberry_field = StrawberryField( python_name=field.name, graphql_name=graphql_name, @@ -98,7 +120,7 @@ def _build_dataclass_creation_fields( permission_classes=( existing_field.permission_classes if existing_field else [] ), - directives=existing_field.directives if existing_field else (), + directives=directives, metadata=existing_field.metadata if existing_field else {}, ) @@ -109,6 +131,26 @@ def _build_dataclass_creation_fields( ) +T = TypeVar("T") + + +def _generate_json_directive( + json_schema: dict[str, Any], json_schema_directive: builtins.type[T] +) -> Optional[T]: + if not dataclasses.is_dataclass(json_schema_directive): + raise JsonSchemaDirectiveNotRegistered(json_schema_directive) + + field_names = {field.name for field in dataclasses.fields(json_schema_directive)} + + if applicable_values := { + to_snake_case(key): value + for key, value in json_schema.items() + if to_snake_case(key) in field_names + }: + return json_schema_directive(**applicable_values) + return None + + if TYPE_CHECKING: from strawberry.experimental.pydantic.conversion_types import ( PydanticModel, @@ -128,6 +170,7 @@ def type( all_fields: bool = False, include_computed: bool = False, use_pydantic_alias: bool = True, + json_schema_directive: Optional[Any] = None, ) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]: def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: compat = PydanticCompat.from_model(model) @@ -184,6 +227,11 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: private_fields = get_private_fields(wrapped) extra_fields_dict = {field.name: field for field in extra_strawberry_fields} + fields_json_schema = ( + compat.get_model_json_schema(model).get("properties", {}) + if json_schema_directive + else {} + ) all_model_fields: list[DataclassCreationFields] = [ _build_dataclass_creation_fields( @@ -193,6 +241,8 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: auto_fields_set, use_pydantic_alias, compat=compat, + json_schema_directive=json_schema_directive, + json_schema=fields_json_schema.get(field.name, {}), ) for field_name, field in model_fields.items() if field_name in fields_set @@ -250,10 +300,12 @@ def is_type_of(cls: builtins.type, obj: Any, _info: GraphQLResolveInfo) -> bool: else: kwargs["init"] = False + bases = cls.__orig_bases__ if hasattr(cls, "__orig_bases__") else cls.__bases__ + cls = dataclasses.make_dataclass( cls.__name__, [field.to_tuple() for field in all_model_fields], - bases=cls.__bases__, + bases=bases, namespace=namespace, **kwargs, # type: ignore ) @@ -317,6 +369,7 @@ def input( directives: Optional[Sequence[object]] = (), all_fields: bool = False, use_pydantic_alias: bool = True, + json_schema_directive: Optional[builtins.type] = None, ) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]: """Convenience decorator for creating an input type from a Pydantic model. @@ -334,6 +387,7 @@ def input( directives=directives, all_fields=all_fields, use_pydantic_alias=use_pydantic_alias, + json_schema_directive=json_schema_directive, ) diff --git a/tests/experimental/pydantic/schema/test_basic.py b/tests/experimental/pydantic/schema/test_basic.py index 91e6d7317e..a8b92b4b06 100644 --- a/tests/experimental/pydantic/schema/test_basic.py +++ b/tests/experimental/pydantic/schema/test_basic.py @@ -1,6 +1,6 @@ import textwrap from enum import Enum -from typing import Optional, Union +from typing import Annotated, Generic, Optional, TypeAlias, TypeVar, Union import pydantic @@ -529,6 +529,74 @@ def user(self) -> User: assert result.data["user"]["password"] is None +def test_nested_type_with_resolved_generic(): + A = TypeVar("A") + + class Hobby(pydantic.BaseModel, Generic[A]): + name: A + + @strawberry.experimental.pydantic.type(Hobby) + class HobbyType(Generic[A]): + name: strawberry.auto + + class User(pydantic.BaseModel): + hobby: Hobby[str] + + @strawberry.experimental.pydantic.type(User) + class UserType: + hobby: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType(hobby=HobbyType(name="Skii")) + + schema = strawberry.Schema(query=Query) + + query = "{ user { hobby { name } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["hobby"]["name"] == "Skii" + + +def test_nested_type_with_resolved_field_generic(): + Count: TypeAlias = Annotated[float, pydantic.Field(ge=0)] + + A = TypeVar("A") + + class Hobby(pydantic.BaseModel, Generic[A]): + count: A + + @strawberry.experimental.pydantic.type(Hobby) + class HobbyType(Generic[A]): + count: strawberry.auto + + class User(pydantic.BaseModel): + hobby: Hobby[Count] + + @strawberry.experimental.pydantic.type(User) + class UserType: + hobby: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> UserType: + return UserType(hobby=HobbyType(count=2)) + + schema = strawberry.Schema(query=Query) + + query = "{ user { hobby { count } } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["hobby"]["count"] == 2 + + @needs_pydantic_v1 def test_basic_type_with_constrained_list(): class FriendList(pydantic.ConstrainedList): diff --git a/tests/experimental/pydantic/schema/test_json_schema_capture.py b/tests/experimental/pydantic/schema/test_json_schema_capture.py new file mode 100644 index 0000000000..de8403a554 --- /dev/null +++ b/tests/experimental/pydantic/schema/test_json_schema_capture.py @@ -0,0 +1,140 @@ +import textwrap +from typing import Annotated, Optional, Union + +import pydantic + +import strawberry +from strawberry.schema_directive import Location +from tests.experimental.pydantic.utils import needs_pydantic_v2 + + +def test_basic_type_field_list(): + class UserModel(pydantic.BaseModel): + age: Annotated[int, pydantic.Field(gt=0, json_schema_extra={"test": 0})] + password: pydantic.json_schema.SkipJsonSchema[Optional[str]] + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class JsonSchema: + test: int + exclusive_minimum: Optional[int] = None + + @strawberry.experimental.pydantic.type(UserModel, json_schema_directive=JsonSchema) + class User: + age: strawberry.auto + password: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, password="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + directive @jsonSchema(test: Int!, exclusiveMinimum: Int = null) on FIELD_DEFINITION + + type Query { + user: User! + } + + type User { + age: Int! @jsonSchema(test: 0, exclusiveMinimum: 0) + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { age } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 1 + + +@needs_pydantic_v2 +def test_can_use_both_pydantic_1_and_2(): + import pydantic + from pydantic import v1 as pydantic_v1 + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class JsonSchema: + minimum: Optional[int] = None + + class UserModel(pydantic.BaseModel): + age: Annotated[int, pydantic.Field(ge=0)] + name: Optional[str] + + @strawberry.experimental.pydantic.type(UserModel, json_schema_directive=JsonSchema) + class User: + age: strawberry.auto + name: strawberry.auto + + class LegacyUserModel(pydantic_v1.BaseModel): + age: int + name: Optional[str] + int_field: pydantic.v1.NonNegativeInt = 1 + + @strawberry.experimental.pydantic.type( + LegacyUserModel, json_schema_directive=JsonSchema + ) + class LegacyUser: + age: strawberry.auto + name: strawberry.auto + int_field: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self, id: strawberry.ID) -> Union[User, LegacyUser]: + if id == "legacy": + return LegacyUser(age=1, name="legacy") + + return User(age=1, name="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + directive @jsonSchema(minimum: Int = null) on FIELD_DEFINITION + + type LegacyUser { + age: Int! + name: String + intField: Int! @jsonSchema(minimum: 0) + } + + type Query { + user(id: ID!): UserLegacyUser! + } + + type User { + age: Int! @jsonSchema(minimum: 0) + name: String + } + + union UserLegacyUser = User | LegacyUser + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = """ + query ($id: ID!) { + user(id: $id) { + __typename + ... on User { name } + ... on LegacyUser { name } + } + } + """ + + result = schema.execute_sync(query, variable_values={"id": "new"}) + + assert not result.errors + assert result.data == {"user": {"__typename": "User", "name": "ABC"}} + + result = schema.execute_sync(query, variable_values={"id": "legacy"}) + + assert not result.errors + assert result.data == {"user": {"__typename": "LegacyUser", "name": "legacy"}} From b33fa676612e3e81fa5a56a1a16ded77c7786191 Mon Sep 17 00:00:00 2001 From: Boudewijn van Groos Date: Fri, 28 Nov 2025 15:22:57 +0100 Subject: [PATCH 2/2] Handle pydantic aliases in directive collection --- strawberry/experimental/pydantic/_compat.py | 12 ++- .../experimental/pydantic/object_type.py | 2 +- strawberry/types/field.py | 1 + .../schema/test_json_schema_capture.py | 100 ++++++++++++++++++ 4 files changed, 110 insertions(+), 5 deletions(-) diff --git a/strawberry/experimental/pydantic/_compat.py b/strawberry/experimental/pydantic/_compat.py index 4e660fc7c1..a93583903a 100644 --- a/strawberry/experimental/pydantic/_compat.py +++ b/strawberry/experimental/pydantic/_compat.py @@ -178,8 +178,10 @@ def get_model_fields( new_fields |= self.get_model_computed_fields(model) return new_fields - def get_model_json_schema(self, model: type[BaseModel]) -> dict[str, Any]: - return model.model_json_schema() + def get_model_json_schema( + self, model: type[BaseModel], by_alias: bool + ) -> dict[str, Any]: + return model.model_json_schema(by_alias=by_alias) @cached_property def fields_map(self) -> dict[Any, Any]: @@ -276,8 +278,10 @@ def get_basic_type(self, type_: Any) -> type[Any]: return type_ - def get_model_json_schema(self, model: type[BaseModel]) -> dict[str, Any]: - return model.schema() + def get_model_json_schema( + self, model: type[BaseModel], by_alias: bool + ) -> dict[str, Any]: + return model.schema(by_alias=by_alias) def model_dump(self, model_instance: BaseModel) -> dict[Any, Any]: return model_instance.dict() diff --git a/strawberry/experimental/pydantic/object_type.py b/strawberry/experimental/pydantic/object_type.py index f4e9b55229..67f9b24737 100644 --- a/strawberry/experimental/pydantic/object_type.py +++ b/strawberry/experimental/pydantic/object_type.py @@ -228,7 +228,7 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]: extra_fields_dict = {field.name: field for field in extra_strawberry_fields} fields_json_schema = ( - compat.get_model_json_schema(model).get("properties", {}) + compat.get_model_json_schema(model, by_alias=False).get("properties", {}) if json_schema_directive else {} ) diff --git a/strawberry/types/field.py b/strawberry/types/field.py index a6d6c1adc2..dcf96bfbcd 100644 --- a/strawberry/types/field.py +++ b/strawberry/types/field.py @@ -374,6 +374,7 @@ def copy_with( override_type = type_.copy_with(type_var_map) elif isinstance(type_, StrawberryType): override_type = type_.copy_with(type_var_map) + # TODO: add support for predefined fields in generic (or field factory in generic) if override_type is not None: new_field.type_annotation = StrawberryAnnotation( diff --git a/tests/experimental/pydantic/schema/test_json_schema_capture.py b/tests/experimental/pydantic/schema/test_json_schema_capture.py index de8403a554..6eccfed3a5 100644 --- a/tests/experimental/pydantic/schema/test_json_schema_capture.py +++ b/tests/experimental/pydantic/schema/test_json_schema_capture.py @@ -138,3 +138,103 @@ def user(self, id: strawberry.ID) -> Union[User, LegacyUser]: assert not result.errors assert result.data == {"user": {"__typename": "LegacyUser", "name": "legacy"}} + + +def test_basic_with_alias_without_using_them(): + class UserModel(pydantic.BaseModel): + age: Annotated[ + int, pydantic.Field(gt=0, json_schema_extra={"test": 0}, alias="userAge") + ] + password: pydantic.json_schema.SkipJsonSchema[Optional[str]] + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class JsonSchema: + test: int + exclusive_minimum: Optional[int] = None + + @strawberry.experimental.pydantic.type( + UserModel, json_schema_directive=JsonSchema, use_pydantic_alias=False + ) + class User: + age: strawberry.auto + password: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, password="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + directive @jsonSchema(test: Int!, exclusiveMinimum: Int = null) on FIELD_DEFINITION + + type Query { + user: User! + } + + type User { + age: Int! @jsonSchema(test: 0, exclusiveMinimum: 0) + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { age } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["age"] == 1 + + +def test_basic_with_alias_and_use_them(): + class UserModel(pydantic.BaseModel): + age: Annotated[ + int, pydantic.Field(gt=0, json_schema_extra={"test": 0}, alias="userAge") + ] + password: pydantic.json_schema.SkipJsonSchema[Optional[str]] + + @strawberry.schema_directive(locations=[Location.FIELD_DEFINITION]) + class JsonSchema: + test: int + exclusive_minimum: Optional[int] = None + + @strawberry.experimental.pydantic.type( + UserModel, json_schema_directive=JsonSchema, use_pydantic_alias=True + ) + class User: + age: strawberry.auto + password: strawberry.auto + + @strawberry.type + class Query: + @strawberry.field + def user(self) -> User: + return User(age=1, password="ABC") + + schema = strawberry.Schema(query=Query) + + expected_schema = """ + directive @jsonSchema(test: Int!, exclusiveMinimum: Int = null) on FIELD_DEFINITION + + type Query { + user: User! + } + + type User { + userAge: Int! @jsonSchema(test: 0, exclusiveMinimum: 0) + password: String + } + """ + + assert str(schema) == textwrap.dedent(expected_schema).strip() + + query = "{ user { userAge } }" + + result = schema.execute_sync(query) + + assert not result.errors + assert result.data["user"]["userAge"] == 1