from __future__ import annotations
import functools
import inspect
from base64 import b32encode
from collections.abc import Iterator, MutableMapping
from copy import copy
from enum import Enum, IntEnum
from hashlib import sha3_224
from typing import TYPE_CHECKING, Any, TypeAlias, cast
from pydantic import ConfigDict, computed_field, create_model
from pydantic import Field as PydanticField
from pydantic.fields import ComputedFieldInfo
from tortoise import (
BackwardFKRelation,
BackwardOneToOneRelation,
ForeignKeyFieldInstance,
ManyToManyFieldInstance,
OneToOneFieldInstance,
)
from tortoise.contrib.pydantic.base import PydanticListModel, PydanticModel
from tortoise.contrib.pydantic.descriptions import (
ComputedFieldDescription,
ModelDescription,
PydanticMetaData,
)
from tortoise.contrib.pydantic.utils import get_annotations
from tortoise.exceptions import NoValuesFetched
from tortoise.fields import Field, JSONField
from tortoise.fields.data import CharEnumFieldInstance, IntEnumFieldInstance
if TYPE_CHECKING: # pragma: nocoverage
from tortoise.models import Model
# Type alias for a single entry in the recursion stack: (model_class, field_name, max_recursion)
StackEntry: TypeAlias = tuple["type[Model]", str, int]
# Type alias for property values stored in _properties.
# Regular fields are stored as (type, FieldInfo) tuples; computed fields as decorator instances.
PropertyValue: TypeAlias = "tuple[type, Any] | Any"
_MODEL_INDEX: dict[str, type[PydanticModel]] = {}
"""
The index works as follows:
1. the hash is calculated from the following:
- the fully qualified name of the model
- the names of the contained fields
- the names of all relational fields and the corresponding names of the pydantic model.
This is because if the model is not yet fully initialized, the relational fields are not yet present.
2. the hash does not take into account the resulting name of the model; this must be checked separately.
3. the hash can only be calculated after a complete analysis of the given model.
"""
def _br_it(val: str) -> str:
return val.replace("\n", "<br/>").strip()
def _cleandoc(obj: Any) -> str:
return _br_it(inspect.cleandoc(obj.__doc__ or ""))
[docs]
class FieldMap(MutableMapping[str, Field | ComputedFieldDescription]):
def __init__(self, meta: PydanticMetaData, pk_field: Field | None = None) -> None:
self._field_map: dict[str, Field | ComputedFieldDescription] = {}
self.pk_raw_field = pk_field.model_field_name if pk_field is not None else ""
if pk_field:
self.pk_raw_field = pk_field.model_field_name
self.field_map_update([pk_field], meta)
self.computed_fields: dict[str, ComputedFieldDescription] = {}
def __delitem__(self, __key: str) -> None:
self._field_map.__delitem__(__key)
def __getitem__(self, __key: str) -> Field | ComputedFieldDescription:
return self._field_map.__getitem__(__key)
def __len__(self) -> int: # pragma: no-coverage
return self._field_map.__len__()
def __iter__(self) -> Iterator[str]:
return self._field_map.__iter__()
def __setitem__(self, __key: str, __value: Field | ComputedFieldDescription) -> None:
self._field_map.__setitem__(__key, __value)
def sort_alphabetically(self) -> None:
self._field_map = {k: self._field_map[k] for k in sorted(self._field_map)}
def sort_definition_order(self, cls: type[Model], computed: tuple[str, ...]) -> None:
self._field_map = {
k: self._field_map[k]
for k in tuple(cls._meta.fields_map.keys()) + computed
if k in self._field_map
}
def field_map_update(self, fields: list[Field], meta: PydanticMetaData) -> None:
for field in fields:
name = field.model_field_name
# Include or exclude field
if (meta.include and name not in meta.include) or name in meta.exclude:
continue
# Remove raw fields
if isinstance(field, ForeignKeyFieldInstance):
raw_field = field.source_field
if (
raw_field is not None
and meta.exclude_raw_fields
and raw_field != self.pk_raw_field
):
self.pop(raw_field, None)
self[name] = field
def computed_field_map_update(self, computed: tuple[str, ...], cls: type[Model]) -> None:
self._field_map.update(
{
k: ComputedFieldDescription(
function=getattr(cls, k),
description=None,
)
for k in computed
}
)
[docs]
def pydantic_queryset_creator(
cls: type[Model],
*,
name: str | None = None,
exclude: tuple[str, ...] = (),
include: tuple[str, ...] = (),
computed: tuple[str, ...] = (),
allow_cycles: bool | None = None,
sort_alphabetically: bool | None = None,
) -> type[PydanticListModel]:
"""
Function to build a `Pydantic Model <https://docs.pydantic.dev/latest/concepts/models/>`__ list off Tortoise Model.
:param cls: The Tortoise Model to put in a list.
:param name: Specify a custom name explicitly, instead of a generated name.
The list generated name is currently naive and merely adds a "s" to the end
of the singular name.
:param exclude: Extra fields to exclude from the provided model.
:param include: Extra fields to include from the provided model.
:param computed: Extra computed fields to include from the provided model.
:param allow_cycles: Do we allow any cycles in the generated model?
This is only useful for recursive/self-referential models.
A value of ``False`` (the default) will prevent any and all backtracking.
:param sort_alphabetically: Sort the parameters alphabetically instead of Field-definition order.
The default order would be:
* Field definition order +
* order of reverse relations (as discovered) +
* order of computed functions (as provided).
"""
submodel = pydantic_model_creator(
cls,
exclude=exclude,
include=include,
computed=computed,
allow_cycles=allow_cycles,
sort_alphabetically=sort_alphabetically,
name=name,
)
lname = name or f"{submodel.__name__}_list"
model = create_model(
lname,
__base__=PydanticListModel,
root=(list[submodel], PydanticField(default_factory=list)), # type: ignore
)
model.__doc__ = _cleandoc(cls)
model.model_config["title"] = name or f"{submodel.model_config['title']}_list"
model.model_config["submodel"] = submodel # type: ignore
return model
class PydanticModelCreator:
def __init__(
self,
cls: type[Model],
name: str | None = None,
exclude: tuple[str, ...] | None = None,
include: tuple[str, ...] | None = None,
computed: tuple[str, ...] | None = None,
optional: tuple[str, ...] | None = None,
allow_cycles: bool | None = None,
sort_alphabetically: bool | None = None,
exclude_readonly: bool = False,
meta_override: type | None = None,
model_config: ConfigDict | None = None,
validators: dict[str, Any] | None = None,
module: str = __name__,
_stack: tuple[StackEntry, ...] = (),
_as_submodel: bool = False,
) -> None:
self._cls: type[Model] = cls
self._stack: tuple[StackEntry, ...] = _stack
self._is_default: bool = (
exclude is None
and include is None
and computed is None
and optional is None
and sort_alphabetically is None
and allow_cycles is None
and meta_override is None
and not exclude_readonly
)
if exclude is None:
exclude = ()
if include is None:
include = ()
if computed is None:
computed = ()
if optional is None:
optional = ()
if meta := getattr(cls, "PydanticMeta", None):
meta_from_class = PydanticMetaData.from_pydantic_meta(meta)
else: # default
meta_from_class = PydanticMetaData()
if meta_override:
meta_from_class = meta_from_class.construct_pydantic_meta(meta_override)
self.meta = meta_from_class.finalize_meta(
exclude=exclude,
include=include,
computed=computed,
allow_cycles=allow_cycles,
sort_alphabetically=sort_alphabetically,
model_config=model_config,
)
self._exclude_read_only: bool = exclude_readonly
self._fqname = cls.__module__ + "." + cls.__qualname__
self._name: str
self._title: str
self.given_name = name
self.__hash: str = ""
self._as_submodel = _as_submodel
self._annotations = get_annotations(cls)
self._pconfig: ConfigDict
self._properties: dict[str, PropertyValue] = dict()
self._relational_fields_index: list[tuple[str, str]] = list()
self._model_description: ModelDescription = ModelDescription.from_model(cls)
self._field_map: FieldMap = self._initialize_field_map()
self._construct_field_map()
self._optional = optional
self._validators = validators
self._module = module
self._stack = _stack
@property
def _hash(self) -> str:
if self.__hash == "":
field_info = []
for name, prop in self._properties.items():
if isinstance(prop, tuple):
field_info.append(f"{name}:{prop[0]}")
else:
field_info.append(f"{name}:computed")
hashval = (
f"{self._fqname};"
f"{field_info};"
f"{self._relational_fields_index};"
f"{self._optional};"
f"{self.meta.allow_cycles};"
f"{self._exclude_read_only};"
f"{self.meta.computed}"
)
self.__hash = (
b32encode(sha3_224(hashval.encode("utf-8")).digest()).decode("utf-8").lower()[:6]
)
return self.__hash
def get_name(self) -> tuple[str, str]:
# If arguments are specified (different from the defaults), we append a hash to the
# class name, to make it unique
# We don't check by stack, as cycles get explicitly renamed.
# When called later, include is explicitly set, so fence passes.
if self.given_name is not None:
return self.given_name, self.given_name
name = f"{self._fqname}:{self._hash}" if not self._is_default else self._fqname
name = f"{name}:leaf" if self._as_submodel else name
return name, self._cls.__name__
def _initialize_pconfig(self) -> ConfigDict:
pconfig: ConfigDict = PydanticModel.model_config.copy()
if self.meta.model_config:
pconfig.update(self.meta.model_config)
if "title" not in pconfig:
pconfig["title"] = self._title
if "extra" not in pconfig:
pconfig["extra"] = "forbid"
return pconfig
def _initialize_field_map(self) -> FieldMap:
return (
FieldMap(self.meta)
if self._exclude_read_only
else FieldMap(self.meta, pk_field=self._model_description.pk_field)
)
def _construct_field_map(self) -> None:
self._field_map.field_map_update(fields=self._model_description.data_fields, meta=self.meta)
if not self._exclude_read_only:
for fields in (
self._model_description.fk_fields,
self._model_description.o2o_fields,
self._model_description.m2m_fields,
):
self._field_map.field_map_update(fields, self.meta)
if self.meta.backward_relations:
for fields in (
self._model_description.backward_fk_fields,
self._model_description.backward_o2o_fields,
):
self._field_map.field_map_update(fields, self.meta)
self._field_map.computed_field_map_update(self.meta.computed, self._cls)
if self.meta.sort_alphabetically:
self._field_map.sort_alphabetically()
else:
self._field_map.sort_definition_order(self._cls, self.meta.computed)
def create_pydantic_model(self) -> type[PydanticModel]:
for field_name, field in self._field_map.items():
self._process_field(field_name, field)
self._name, self._title = self.get_name()
if self._hash in _MODEL_INDEX:
hashed_model = _MODEL_INDEX[self._hash]
if hashed_model.__name__ == self._name:
return _MODEL_INDEX[self._hash]
self._pconfig = self._initialize_pconfig()
computed_fields: dict[str, Any] = {}
common_fields: dict[str, Any] = {}
for k, v in self._properties.items():
if isinstance(getattr(v, "decorator_info", None), ComputedFieldInfo):
computed_fields[k] = v
else:
common_fields[k] = v
base_model = type(
"BasePydanticModel",
(PydanticModel,),
{"model_config": self._pconfig, **computed_fields},
)
model: type[PydanticModel] = create_model(
self._name,
__base__=base_model,
__module__=self._module,
__validators__=self._validators,
**common_fields,
)
model.__doc__ = _cleandoc(self._cls)
model.model_config["orig_model"] = self._cls # type: ignore
_MODEL_INDEX[self._hash] = model
return model
def _process_field(
self,
field_name: str,
field: Field | ComputedFieldDescription,
) -> None:
if isinstance(field, Field):
self._process_orm_field(field_name, field)
elif isinstance(field, ComputedFieldDescription):
self._process_computed_field_entry(field_name, field)
def _process_orm_field(self, field_name: str, field: Field) -> None:
json_schema_extra: dict[str, Any] = {}
fconfig: dict[str, Any] = {
"json_schema_extra": json_schema_extra,
}
field_property, _ = self._process_normal_field(
field_name, field, json_schema_extra, fconfig
)
if field_property:
fconfig["title"] = field_name.replace("_", " ").title()
description = _br_it(field.docstring or field.description or "")
if description:
fconfig["description"] = description
if field_name in self._optional or (
field.default is not None and not callable(field.default)
):
self._properties[field_name] = (
field_property,
PydanticField(default=field.default, **fconfig),
)
else:
if json_schema_extra.get("nullable") or (
self._exclude_read_only and json_schema_extra.get("readOnly")
):
# see: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
fconfig["default"] = None
self._properties[field_name] = (field_property, PydanticField(**fconfig))
def _process_computed_field_entry(
self, field_name: str, field: ComputedFieldDescription
) -> None:
field_property = self._process_computed_field(field)
if field_property:
self._properties[field_name] = field_property
def _process_normal_field(
self,
field_name: str,
field: Field,
json_schema_extra: dict[str, Any],
fconfig: dict[str, Any],
) -> tuple[Any | None, bool]:
if isinstance(
field, (ForeignKeyFieldInstance, OneToOneFieldInstance, BackwardOneToOneRelation)
):
return self._process_single_field_relation(field_name, field, json_schema_extra), True
elif isinstance(field, (BackwardFKRelation, ManyToManyFieldInstance)):
return self._process_many_field_relation(field_name, field), False
elif field.field_type is JSONField:
return Any, False
return self._process_data_field(field_name, field, json_schema_extra, fconfig), False
def _process_single_field_relation(
self,
field_name: str,
field: ForeignKeyFieldInstance | OneToOneFieldInstance | BackwardOneToOneRelation,
json_schema_extra: dict[str, Any],
) -> type[PydanticModel] | None:
python_type = getattr(field, "related_model", field.field_type)
model: type[PydanticModel] | None = self._get_submodel(python_type, field_name)
if model:
self._relational_fields_index.append((field_name, model.__name__))
if field.null:
json_schema_extra["nullable"] = True
if field.null or field.default is not None:
return cast(type[PydanticModel] | None, model | None)
return model
return None
def _process_many_field_relation(
self,
field_name: str,
field: BackwardFKRelation | ManyToManyFieldInstance,
) -> type[list[type[PydanticModel]]] | None:
python_type = field.related_model
model = self._get_submodel(python_type, field_name)
if model:
self._relational_fields_index.append((field_name, model.__name__))
return list[model] # type: ignore
return None
def _process_data_field(
self,
field_name: str,
field: Field,
json_schema_extra: dict[str, Any],
fconfig: dict[str, Any],
) -> Any | None:
annotation = self._annotations.get(field_name, None)
constraints = copy(field.constraints)
if "readOnly" in constraints:
json_schema_extra["readOnly"] = constraints["readOnly"]
del constraints["readOnly"]
fconfig.update(constraints)
python_type: type[Enum] | type[IntEnum] | type
if isinstance(field, (IntEnumFieldInstance, CharEnumFieldInstance)):
python_type = field.enum_type
else:
python_type = getattr(field, "related_model", field.field_type)
ptype = python_type
if field.null:
json_schema_extra["nullable"] = True
if not field.pk and (field_name in self._optional or field.null):
ptype = ptype | None
if not (self._exclude_read_only and json_schema_extra.get("readOnly") is True):
return annotation or ptype
return None
def _process_computed_field(
self,
field: ComputedFieldDescription,
) -> Any | None:
func = field.function
annotation = get_annotations(self._cls, func).get("return", None)
if annotation is not None:
original_func = func
@functools.wraps(original_func)
def wrapped_func(self_pydantic):
orm_obj = getattr(self_pydantic, "__orm_obj__", None)
if orm_obj is not None:
try:
return original_func(orm_obj)
except NoValuesFetched:
raise NoValuesFetched(
f"Computed field '{original_func.__name__}' tried to access a "
f"relation that has not been fetched. Either include the relation "
f"in the Pydantic model so it is auto-prefetched, or call "
f"fetch_related() before serialization."
)
return original_func(self_pydantic)
comment = _cleandoc(func)
c_f = computed_field(return_type=annotation, description=comment)
return c_f(wrapped_func)
return None
@staticmethod
def _create_submodel(
cls: type[Model],
*,
stack: tuple[StackEntry, ...],
exclude: tuple[str, ...] = (),
include: tuple[str, ...] = (),
computed: tuple[str, ...] = (),
name: str | None = None,
allow_cycles: bool = False,
sort_alphabetically: bool | None = None,
) -> type[PydanticModel] | None:
"""Create a Pydantic submodel with recursion protection against cyclic references."""
if not allow_cycles and cls in (c[0] for c in stack[:-1]):
return None
level = 1
for _, _, parent_max_recursion in stack[1:]:
if level >= parent_max_recursion:
return None
level += 1
pmc = PydanticModelCreator(
cls,
exclude=exclude,
include=include,
computed=computed,
name=name,
_stack=stack,
allow_cycles=allow_cycles,
sort_alphabetically=sort_alphabetically,
_as_submodel=True,
)
return pmc.create_pydantic_model()
def _get_submodel(
self, _model: type[Model] | None, field_name: str
) -> type[PydanticModel] | None:
"""Get Pydantic model for the submodel"""
if _model:
new_stack = self._stack + ((self._cls, field_name, self.meta.max_recursion),)
prefix_len = len(field_name) + 1
def get_fields_to_carry_on(field_tuple: tuple[str, ...]) -> tuple[str, ...]:
return tuple(
str(v[prefix_len:]) for v in field_tuple if v.startswith(field_name + ".")
)
pmodel = self._create_submodel(
_model,
exclude=get_fields_to_carry_on(self.meta.exclude),
include=get_fields_to_carry_on(self.meta.include),
computed=get_fields_to_carry_on(self.meta.computed),
stack=new_stack,
allow_cycles=self.meta.allow_cycles,
sort_alphabetically=self.meta.sort_alphabetically,
)
else:
pmodel = None
if pmodel is None:
self.meta.exclude += (field_name,)
return pmodel
[docs]
def pydantic_model_creator(
cls: type[Model],
*,
name: str | None = None,
exclude: tuple[str, ...] | None = None,
include: tuple[str, ...] | None = None,
computed: tuple[str, ...] | None = None,
optional: tuple[str, ...] | None = None,
allow_cycles: bool | None = None,
sort_alphabetically: bool | None = None,
exclude_readonly: bool = False,
meta_override: type | None = None,
model_config: ConfigDict | None = None,
validators: dict[str, Any] | None = None,
module: str = __name__,
) -> type[PydanticModel]:
"""
Function to build `Pydantic Model <https://docs.pydantic.dev/latest/concepts/models/>`__ off Tortoise Model.
:param cls: The Tortoise Model
:param name: Specify a custom name explicitly, instead of a generated name.
:param exclude: Extra fields to exclude from the provided model.
:param include: Extra fields to include from the provided model.
:param computed: Extra computed fields to include from the provided model.
:param optional: Extra optional fields for the provided model.
:param allow_cycles: Do we allow any cycles in the generated model?
This is only useful for recursive/self-referential models.
A value of ``False`` (the default) will prevent any and all backtracking.
:param sort_alphabetically: Sort the parameters alphabetically instead of Field-definition order.
The default order would be:
* Field definition order +
* order of reverse relations (as discovered) +
* order of computed functions (as provided).
:param exclude_readonly: Build a subset model that excludes any readonly fields
:param meta_override: A PydanticMeta class to override model's values.
:param model_config: A custom config to use as pydantic config.
:param validators: A dictionary of methods that validate fields.
:param module: The name of the module that the model belongs to.
Note: Created pydantic model uses config_class parameter and PydanticMeta's
config_class as its Config class's bases(Only if provided!), but it
ignores ``fields`` config. pydantic_model_creator will generate fields by
include/exclude/computed parameters automatically.
"""
pmc = PydanticModelCreator(
cls=cls,
name=name,
exclude=exclude,
include=include,
computed=computed,
optional=optional,
allow_cycles=allow_cycles,
sort_alphabetically=sort_alphabetically,
exclude_readonly=exclude_readonly,
meta_override=meta_override,
model_config=model_config,
validators=validators,
module=module,
)
return pmc.create_pydantic_model()