Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
Release type: minor

This release fixes the pydantic support for generics and allows capture of the
Pydantic JSON schema attributes through a schema directive.
58 changes: 58 additions & 0 deletions docs/integrations/pydantic.md
Original file line number Diff line number Diff line change
Expand Up @@ -519,3 +519,61 @@ user_type = UserType(id="abc", content_name="Bob", content_description=None)
print(user_type.to_pydantic())
# id='abc' content={<ContentType.NAME: 'name'>: 'Bob'}
```

## Schema directives to capture JSON schema data

The pydantic conversion also supports capturing the JSON schema metadata from
Pydantic. Note the fields in the directive must match the json schema names.

```python
from pydantic import BaseModel, Field
from typing import Annotated, Optional
from strawberry.schema_directive import Location
import strawberry


class User(BaseModel):
id: Annotated[int, Field(gt=0)]
name: Annotated[str, Field(json_schema_extra={"name_type": "full"})]


@strawberry.schema_directive(locations=[Location.FIELD_DEFINITION])
class MyJsonSchema:
exclusive_minimum: Optional[int] = None
name_type: Optional[str] = None


@strawberry.experimental.pydantic.type(model=User, json_schema_directive=MyJsonSchema)
class UserType:
id: strawberry.auto
name: strawberry.auto


@strawberry.type
class Query:
@strawberry.field
def test() -> UserType:
return UserType.from_pydantic(User(id=123, name="John Doe"))


schema = strawberry.Schema(query=Query)
```

Now if [the schema is exported](../guides/schema-export), the result will
contain:

```graphql
directive @myJsonSchema(
exclusiveMinimum: Int = null
nameType: String = null
) on FIELD_DEFINITION

type Query {
test: UserType!
}

type UserType {
id: Int!
name: String! @myJsonSchema(exclusiveMinimum: null, nameType: "full")
}
```
6 changes: 6 additions & 0 deletions strawberry/experimental/pydantic/_compat.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,6 +178,9 @@ def get_model_fields(
new_fields |= self.get_model_computed_fields(model)
return new_fields

def get_model_json_schema(self, model: type[BaseModel]) -> dict[str, Any]:
return model.model_json_schema()

@cached_property
def fields_map(self) -> dict[Any, Any]:
return get_fields_map_for_v2()
Expand Down Expand Up @@ -273,6 +276,9 @@ def get_basic_type(self, type_: Any) -> type[Any]:

return type_

def get_model_json_schema(self, model: type[BaseModel]) -> dict[str, Any]:
return model.schema()

def model_dump(self, model_instance: BaseModel) -> dict[Any, Any]:
return model_instance.dict()

Expand Down
9 changes: 9 additions & 0 deletions strawberry/experimental/pydantic/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,3 +53,12 @@ def __init__(
)

super().__init__(message)


class JsonSchemaDirectiveNotRegistered(Exception):
def __init__(self, json_schema_directive: type) -> None:
message = (
f"{json_schema_directive} not registered as a strawberry schema directive"
)

super().__init__(message)
5 changes: 5 additions & 0 deletions strawberry/experimental/pydantic/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,11 @@ def replace_types_recursively(
origin = get_origin(type_)

if not origin or not hasattr(type_, "__args__"):
if (
hasattr(basic_type, "__pydantic_generic_metadata__")
and basic_type.__pydantic_generic_metadata__["args"]
):
return replaced_type[basic_type.__pydantic_generic_metadata__["args"]]
return replaced_type

converted = tuple(
Expand Down
60 changes: 57 additions & 3 deletions strawberry/experimental/pydantic/object_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
Any,
Callable,
Optional,
TypeVar,
cast,
)

Expand All @@ -20,7 +21,10 @@
convert_pydantic_model_to_strawberry_class,
convert_strawberry_class_to_pydantic_model,
)
from strawberry.experimental.pydantic.exceptions import MissingFieldsListError
from strawberry.experimental.pydantic.exceptions import (
JsonSchemaDirectiveNotRegistered,
MissingFieldsListError,
)
from strawberry.experimental.pydantic.fields import replace_types_recursively
from strawberry.experimental.pydantic.utils import (
DataclassCreationFields,
Expand All @@ -33,6 +37,7 @@
from strawberry.types.field import StrawberryField
from strawberry.types.object_type import _process_type, _wrap_dataclass
from strawberry.types.type_resolver import _get_fields
from strawberry.utils.str_converters import to_snake_case

if TYPE_CHECKING:
import builtins
Expand Down Expand Up @@ -62,6 +67,8 @@ def _build_dataclass_creation_fields(
auto_fields_set: set[str],
use_pydantic_alias: bool,
compat: PydanticCompat,
json_schema: dict[str, Any],
json_schema_directive: Optional[builtins.type] = None,
) -> DataclassCreationFields:
field_type = (
get_type_for_field(field, is_input, compat=compat)
Expand All @@ -84,6 +91,21 @@ def _build_dataclass_creation_fields(
elif field.has_alias and use_pydantic_alias:
graphql_name = field.alias

directives = existing_field.directives if existing_field else ()
if (
json_schema_directive
and json_schema
and (
json_directive := _generate_json_directive(
json_schema, json_schema_directive
)
)
):
directives = (
*directives,
json_directive,
)

strawberry_field = StrawberryField(
python_name=field.name,
graphql_name=graphql_name,
Expand All @@ -98,7 +120,7 @@ def _build_dataclass_creation_fields(
permission_classes=(
existing_field.permission_classes if existing_field else []
),
directives=existing_field.directives if existing_field else (),
directives=directives,
metadata=existing_field.metadata if existing_field else {},
)

Expand All @@ -109,6 +131,26 @@ def _build_dataclass_creation_fields(
)


T = TypeVar("T")


def _generate_json_directive(
json_schema: dict[str, Any], json_schema_directive: builtins.type[T]
) -> Optional[T]:
if not dataclasses.is_dataclass(json_schema_directive):
raise JsonSchemaDirectiveNotRegistered(json_schema_directive)

field_names = {field.name for field in dataclasses.fields(json_schema_directive)}

if applicable_values := {
to_snake_case(key): value
for key, value in json_schema.items()
if to_snake_case(key) in field_names
}:
return json_schema_directive(**applicable_values)
return None


if TYPE_CHECKING:
from strawberry.experimental.pydantic.conversion_types import (
PydanticModel,
Expand All @@ -128,6 +170,7 @@ def type(
all_fields: bool = False,
include_computed: bool = False,
use_pydantic_alias: bool = True,
json_schema_directive: Optional[Any] = None,
) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]:
def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
compat = PydanticCompat.from_model(model)
Expand Down Expand Up @@ -184,6 +227,11 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
private_fields = get_private_fields(wrapped)

extra_fields_dict = {field.name: field for field in extra_strawberry_fields}
fields_json_schema = (
compat.get_model_json_schema(model).get("properties", {})
if json_schema_directive
else {}
)

all_model_fields: list[DataclassCreationFields] = [
_build_dataclass_creation_fields(
Expand All @@ -193,6 +241,8 @@ def wrap(cls: Any) -> builtins.type[StrawberryTypeFromPydantic[PydanticModel]]:
auto_fields_set,
use_pydantic_alias,
compat=compat,
json_schema_directive=json_schema_directive,
json_schema=fields_json_schema.get(field.name, {}),
)
for field_name, field in model_fields.items()
if field_name in fields_set
Expand Down Expand Up @@ -250,10 +300,12 @@ def is_type_of(cls: builtins.type, obj: Any, _info: GraphQLResolveInfo) -> bool:
else:
kwargs["init"] = False

bases = cls.__orig_bases__ if hasattr(cls, "__orig_bases__") else cls.__bases__

cls = dataclasses.make_dataclass(
cls.__name__,
[field.to_tuple() for field in all_model_fields],
bases=cls.__bases__,
bases=bases,
namespace=namespace,
**kwargs, # type: ignore
)
Expand Down Expand Up @@ -317,6 +369,7 @@ def input(
directives: Optional[Sequence[object]] = (),
all_fields: bool = False,
use_pydantic_alias: bool = True,
json_schema_directive: Optional[builtins.type] = None,
) -> Callable[..., builtins.type[StrawberryTypeFromPydantic[PydanticModel]]]:
"""Convenience decorator for creating an input type from a Pydantic model.

Expand All @@ -334,6 +387,7 @@ def input(
directives=directives,
all_fields=all_fields,
use_pydantic_alias=use_pydantic_alias,
json_schema_directive=json_schema_directive,
)


Expand Down
70 changes: 69 additions & 1 deletion tests/experimental/pydantic/schema/test_basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import textwrap
from enum import Enum
from typing import Optional, Union
from typing import Annotated, Generic, Optional, TypeAlias, TypeVar, Union

import pydantic

Expand Down Expand Up @@ -529,6 +529,74 @@ def user(self) -> User:
assert result.data["user"]["password"] is None


def test_nested_type_with_resolved_generic():
A = TypeVar("A")

class Hobby(pydantic.BaseModel, Generic[A]):
name: A

@strawberry.experimental.pydantic.type(Hobby)
class HobbyType(Generic[A]):
name: strawberry.auto

class User(pydantic.BaseModel):
hobby: Hobby[str]

@strawberry.experimental.pydantic.type(User)
class UserType:
hobby: strawberry.auto

@strawberry.type
class Query:
@strawberry.field
def user(self) -> UserType:
return UserType(hobby=HobbyType(name="Skii"))

schema = strawberry.Schema(query=Query)

query = "{ user { hobby { name } } }"

result = schema.execute_sync(query)

assert not result.errors
assert result.data["user"]["hobby"]["name"] == "Skii"


def test_nested_type_with_resolved_field_generic():
Count: TypeAlias = Annotated[float, pydantic.Field(ge=0)]

A = TypeVar("A")

class Hobby(pydantic.BaseModel, Generic[A]):
count: A

@strawberry.experimental.pydantic.type(Hobby)
class HobbyType(Generic[A]):
count: strawberry.auto

class User(pydantic.BaseModel):
hobby: Hobby[Count]

@strawberry.experimental.pydantic.type(User)
class UserType:
hobby: strawberry.auto

@strawberry.type
class Query:
@strawberry.field
def user(self) -> UserType:
return UserType(hobby=HobbyType(count=2))

schema = strawberry.Schema(query=Query)

query = "{ user { hobby { count } } }"

result = schema.execute_sync(query)

assert not result.errors
assert result.data["user"]["hobby"]["count"] == 2


@needs_pydantic_v1
def test_basic_type_with_constrained_list():
class FriendList(pydantic.ConstrainedList):
Expand Down
Loading
Loading