import inspect
from base64 import b32encode
from copy import copy
from enum import IntEnum, Enum
from hashlib import sha3_224
from typing import (
TYPE_CHECKING,
Any,
Dict,
List,
MutableMapping,
Optional,
Tuple,
Type,
Union,
)
from pydantic import ConfigDict
from pydantic import Field as PydanticField
from pydantic import computed_field, create_model
from tortoise import (
BackwardFKRelation,
BackwardOneToOneRelation,
ForeignKeyFieldInstance,
ManyToManyFieldInstance,
OneToOneFieldInstance,
)
from tortoise.contrib.pydantic.base import PydanticListModel, PydanticModel
from tortoise.contrib.pydantic.descriptions import (
ComputedFieldDescription,
ModelDescription,
PydanticMetaData,
)
from tortoise.contrib.pydantic.utils import get_annotations
from tortoise.fields import Field, JSONField
from tortoise.fields.data import IntEnumFieldInstance, CharEnumFieldInstance
if TYPE_CHECKING: # pragma: nocoverage
from tortoise.models import Model
# Module-level cache of generated pydantic models, keyed by the short hash that
# PydanticModelCreator._hash computes. PydanticModelCreator.create_pydantic_model()
# looks models up here before building a new one and stores every model it builds.
_MODEL_INDEX: Dict[str, Type[PydanticModel]] = {}
"""
The index works as follows:
1. the hash is calculated from the following:
- the fully qualified name of the model
- the names of the contained fields
- the names of all relational fields and the corresponding names of the pydantic model.
This is because if the model is not yet fully initialized, the relational fields are not yet present.
2. the hash does not take into account the resulting name of the model; this must be checked separately.
3. the hash can only be calculated after a complete analysis of the given model.
"""
def _br_it(val: str) -> str:
return val.replace("\n", "<br/>").strip()
def _cleandoc(obj: Any) -> str:
    """
    Return the docstring of ``obj`` cleaned with :func:`inspect.cleandoc` and passed
    through ``_br_it``; a missing or empty docstring is treated as ``""``.
    """
    raw_doc = obj.__doc__
    if not raw_doc:
        raw_doc = ""
    return _br_it(inspect.cleandoc(raw_doc))
def _pydantic_recursion_protector(
    cls: "Type[Model]",
    *,
    stack: Tuple,
    exclude: Tuple[str, ...] = (),
    include: Tuple[str, ...] = (),
    computed: Tuple[str, ...] = (),
    name=None,
    allow_cycles: bool = False,
    sort_alphabetically: Optional[bool] = None,
) -> Optional[Type[PydanticModel]]:
    """
    Build the pydantic submodel for ``cls`` while guarding against cyclic recursion.

    :param cls: The Tortoise model to build a submodel for.
    :param stack: Tuple of ``(model, field_name, max_recursion)`` entries describing the
        chain of models currently being built; the last entry is the direct parent.
    :param exclude: Fields to exclude from the submodel.
    :param include: Fields to include in the submodel.
    :param computed: Computed fields to include in the submodel.
    :param name: Explicit name for the submodel, instead of a generated one.
    :param allow_cycles: When false, ``cls`` may not already appear in ``stack``
        (the last entry excepted).
    :param sort_alphabetically: Forwarded unchanged to the submodel creator.
    :return: The generated pydantic model, or ``None`` when the submodel is suppressed
        because it would form a cycle or exceed a recursion limit.
    """
    if not allow_cycles and cls in (c[0] for c in stack[:-1]):
        return None

    # The n-th entry after the first (counting from 1) carries its own recursion
    # limit; building stops as soon as the depth reaches any of those limits.
    for level, (_, _, parent_max_recursion) in enumerate(stack[1:], start=1):
        if level >= parent_max_recursion:
            return None

    pmc = PydanticModelCreator(
        cls,
        exclude=exclude,
        include=include,
        computed=computed,
        name=name,
        _stack=stack,
        allow_cycles=allow_cycles,
        sort_alphabetically=sort_alphabetically,
        _as_submodel=True,
    )
    return pmc.create_pydantic_model()
class FieldMap(MutableMapping[str, Union[Field, ComputedFieldDescription]]):
    """
    Mutable mapping from field name to the Tortoise field (or computed-field
    description) that should appear in a generated pydantic model.

    Entries are kept in an insertion-ordered dict; the ``sort_*`` methods rebuild
    that dict to change the order.
    """

    def __init__(self, meta: PydanticMetaData, pk_field: Optional[Field] = None):
        """
        :param meta: Supplies the ``include`` / ``exclude`` / ``exclude_raw_fields``
            settings used when fields are added.
        :param pk_field: Primary-key field to register first; when omitted no
            primary key is registered and ``pk_raw_field`` is ``""``.
        """
        self._field_map: Dict[str, Union[Field, ComputedFieldDescription]] = {}
        self.computed_fields: Dict[str, ComputedFieldDescription] = {}
        # Name of the pk field; field_map_update() never drops it as a "raw" FK column.
        self.pk_raw_field = pk_field.model_field_name if pk_field is not None else ""
        if pk_field is not None:
            self.field_map_update([pk_field], meta)

    def __delitem__(self, __key):
        self._field_map.__delitem__(__key)

    def __getitem__(self, __key):
        return self._field_map.__getitem__(__key)

    def __len__(self):  # pragma: no-coverage
        return self._field_map.__len__()

    def __iter__(self):
        return self._field_map.__iter__()

    def __setitem__(self, __key, __value):
        self._field_map.__setitem__(__key, __value)

    def sort_alphabetically(self) -> None:
        """Reorder the entries by field name."""
        self._field_map = {k: self._field_map[k] for k in sorted(self._field_map)}

    def sort_definition_order(self, cls: "Type[Model]", computed: Tuple[str, ...]) -> None:
        """
        Reorder the entries to follow ``cls._meta.fields_map`` followed by ``computed``.

        Entries whose name appears in neither sequence are dropped.
        """
        self._field_map = {
            k: self._field_map[k]
            for k in tuple(cls._meta.fields_map.keys()) + computed
            if k in self._field_map
        }

    def field_map_update(self, fields: List[Field], meta: PydanticMetaData) -> None:
        """
        Add ``fields`` to the map, honouring the include/exclude settings of ``meta``.

        When a foreign-key field is added and ``meta.exclude_raw_fields`` is set, the
        entry for its ``source_field`` is removed unless that is the primary key.
        """
        for field in fields:
            name = field.model_field_name
            # Include or exclude field
            if (meta.include and name not in meta.include) or name in meta.exclude:
                continue
            # Remove raw fields
            if isinstance(field, ForeignKeyFieldInstance):
                raw_field = field.source_field
                if (
                    raw_field is not None
                    and meta.exclude_raw_fields
                    and raw_field != self.pk_raw_field
                ):
                    self.pop(raw_field, None)
            self[name] = field

    def computed_field_map_update(self, computed: Tuple[str, ...], cls: "Type[Model]") -> None:
        """Add one ``ComputedFieldDescription`` per name in ``computed``, wrapping ``getattr(cls, name)``."""
        self._field_map.update(
            {
                k: ComputedFieldDescription(
                    field_type=callable,
                    function=getattr(cls, k),
                    description=None,
                )
                for k in computed
            }
        )
def pydantic_queryset_creator(
    cls: "Type[Model]",
    *,
    name=None,
    exclude: Tuple[str, ...] = (),
    include: Tuple[str, ...] = (),
    computed: Tuple[str, ...] = (),
    allow_cycles: Optional[bool] = None,
    sort_alphabetically: Optional[bool] = None,
) -> Type[PydanticListModel]:
    """
    Function to build a `Pydantic Model <https://docs.pydantic.dev/latest/concepts/models/>`__ list off Tortoise Model.

    :param cls: The Tortoise Model to put in a list.
    :param name: Specify a custom name explicitly, instead of a generated name.

        The generated list name is currently naive and merely appends ``_list`` to
        the name of the singular model. When given, ``name`` is also passed on as
        the name of the singular model.
    :param exclude: Extra fields to exclude from the provided model.
    :param include: Extra fields to include from the provided model.
    :param computed: Extra computed fields to include from the provided model.
    :param allow_cycles: Do we allow any cycles in the generated model?
        This is only useful for recursive/self-referential models.

        A value of ``False`` (the default) will prevent any and all backtracking.
    :param sort_alphabetically: Sort the parameters alphabetically instead of Field-definition order.

        The default order would be:

            * Field definition order +
            * order of reverse relations (as discovered) +
            * order of computed functions (as provided).
    :return: A ``PydanticListModel`` subclass whose ``root`` is a list of the
        singular model; the singular model is stored under ``model_config["submodel"]``.
    """
    submodel = pydantic_model_creator(
        cls,
        exclude=exclude,
        include=include,
        computed=computed,
        allow_cycles=allow_cycles,
        sort_alphabetically=sort_alphabetically,
        name=name,
    )
    lname = name or f"{submodel.__name__}_list"

    # Creating Pydantic class for the properties generated before
    model = create_model(
        lname,
        __base__=PydanticListModel,
        root=(List[submodel], PydanticField(default_factory=list)),  # type: ignore
    )
    # Copy the Model docstring over
    model.__doc__ = _cleandoc(cls)
    # The title of the model to hide the hash postfix
    model.model_config["title"] = name or f"{submodel.model_config['title']}_list"
    model.model_config["submodel"] = submodel  # type: ignore
    return model
class PydanticModelCreator:
    """
    Builds a pydantic model class from a Tortoise model.

    Construction resolves the effective ``PydanticMetaData`` and collects the fields
    to expose; :meth:`create_pydantic_model` then converts every field and creates
    (or reuses from ``_MODEL_INDEX``) the pydantic class.
    """

    def __init__(
        self,
        cls: "Type[Model]",
        name: Optional[str] = None,
        exclude: Optional[Tuple[str, ...]] = None,
        include: Optional[Tuple[str, ...]] = None,
        computed: Optional[Tuple[str, ...]] = None,
        optional: Optional[Tuple[str, ...]] = None,
        allow_cycles: Optional[bool] = None,
        sort_alphabetically: Optional[bool] = None,
        exclude_readonly: bool = False,
        meta_override: Optional[Type] = None,
        model_config: Optional[ConfigDict] = None,
        validators: Optional[Dict[str, Any]] = None,
        module: str = __name__,
        _stack: tuple = (),
        _as_submodel: bool = False,
    ) -> None:
        """
        :param cls: The Tortoise model to convert.
        :param name: Explicit model name; used verbatim as both name and title.
        :param exclude: Extra fields to exclude.
        :param include: Extra fields to include.
        :param computed: Extra computed fields to include.
        :param optional: Field names that get their field default in the model.
        :param allow_cycles: Whether cycles are allowed in generated submodels.
        :param sort_alphabetically: Sort fields by name instead of definition order.
        :param exclude_readonly: Leave out the primary key, relations and read-only
            data fields.
        :param meta_override: Class whose settings override the model's ``PydanticMeta``.
        :param model_config: Custom pydantic config merged into the model config.
        :param validators: Passed to ``create_model`` as ``__validators__``.
        :param module: Passed to ``create_model`` as ``__module__``.
        :param _stack: Internal; chain of ``(model, field_name, max_recursion)`` parents.
        :param _as_submodel: Internal; marks models created for a relation.
        """
        self._cls: "Type[Model]" = cls
        self._stack: Tuple[Tuple["Type[Model]", str, int], ...] = (
            _stack  # ((Type[Model], field_name, max_recursion),)
        )
        # True only when no customising argument was supplied; get_name() then
        # uses the plain fully qualified name without a hash suffix.
        self._is_default: bool = (
            exclude is None
            and include is None
            and computed is None
            and optional is None
            and sort_alphabetically is None
            and allow_cycles is None
            and meta_override is None
            and not exclude_readonly
        )
        if exclude is None:
            exclude = ()
        if include is None:
            include = ()
        if computed is None:
            computed = ()
        if optional is None:
            optional = ()

        if meta := getattr(cls, "PydanticMeta", None):
            meta_from_class = PydanticMetaData.from_pydantic_meta(meta)
        else:  # default
            meta_from_class = PydanticMetaData()
        if meta_override:
            meta_from_class = meta_from_class.construct_pydantic_meta(meta_override)
        self.meta = meta_from_class.finalize_meta(
            exclude=exclude,
            include=include,
            computed=computed,
            allow_cycles=allow_cycles,
            sort_alphabetically=sort_alphabetically,
            model_config=model_config,
        )

        self._exclude_read_only: bool = exclude_readonly
        self._fqname = cls.__module__ + "." + cls.__qualname__
        self._name: str
        self._title: str
        self.given_name = name
        self.__hash: str = ""
        self._as_submodel = _as_submodel
        self._annotations = get_annotations(cls)
        self._pconfig: ConfigDict
        # Keyword arguments handed to create_model(): field name -> definition.
        self._properties: Dict[str, Any] = dict()
        # (field name, submodel name) for every relation that produced a submodel.
        self._relational_fields_index: List[Tuple[str, str]] = list()
        self._model_description: ModelDescription = ModelDescription.from_model(cls)
        self._field_map: FieldMap = self._initialize_field_map()
        self._construct_field_map()
        self._optional = optional
        self._validators = validators
        self._module = module

    @property
    def _hash(self) -> str:
        """
        Six-character identifier of the analysed model, computed once and cached.

        It is derived from the fully qualified name, the collected property names,
        the relational field index, the optional fields and ``meta.allow_cycles``,
        so it is only meaningful after all fields have been processed.
        """
        if self.__hash == "":
            hashval = (
                f"{self._fqname};{self._properties.keys()};{self._relational_fields_index};{self._optional};"
                f"{self.meta.allow_cycles}"
            )
            self.__hash = (
                b32encode(sha3_224(hashval.encode("utf-8")).digest()).decode("utf-8").lower()[:6]
            )
        return self.__hash

    def get_name(self) -> Tuple[str, str]:
        """Return the ``(name, title)`` pair for the model being created."""
        # If arguments are specified (different from the defaults), we append a hash to the
        # class name, to make it unique
        # We don't check by stack, as cycles get explicitly renamed.
        # When called later, include is explicitly set, so fence passes.
        if self.given_name is not None:
            return self.given_name, self.given_name
        name = f"{self._fqname}:{self._hash}" if not self._is_default else self._fqname
        name = f"{name}:leaf" if self._as_submodel else name
        return name, self._cls.__name__

    def _initialize_pconfig(self) -> ConfigDict:
        """
        Build the pydantic config: a copy of ``PydanticModel.model_config`` updated
        with ``meta.model_config``, with ``title`` and ``extra="forbid"`` filled in
        when they are not already set.
        """
        pconfig: ConfigDict = PydanticModel.model_config.copy()
        if self.meta.model_config:
            pconfig.update(self.meta.model_config)
        if "title" not in pconfig:
            pconfig["title"] = self._title
        if "extra" not in pconfig:
            pconfig["extra"] = "forbid"
        return pconfig

    def _initialize_field_map(self) -> FieldMap:
        """Create the field map; the primary key is registered unless read-only fields are excluded."""
        return (
            FieldMap(self.meta)
            if self._exclude_read_only
            else FieldMap(self.meta, pk_field=self._model_description.pk_field)
        )

    def _construct_field_map(self) -> None:
        """
        Fill the field map with data fields, then (unless read-only fields are
        excluded) forward and backward relations, then computed fields, and finally
        put the entries into the configured order.
        """
        self._field_map.field_map_update(fields=self._model_description.data_fields, meta=self.meta)
        if not self._exclude_read_only:
            for fields in (
                self._model_description.fk_fields,
                self._model_description.o2o_fields,
                self._model_description.m2m_fields,
            ):
                self._field_map.field_map_update(fields, self.meta)
            if self.meta.backward_relations:
                for fields in (
                    self._model_description.backward_fk_fields,
                    self._model_description.backward_o2o_fields,
                ):
                    self._field_map.field_map_update(fields, self.meta)
        self._field_map.computed_field_map_update(self.meta.computed, self._cls)
        if self.meta.sort_alphabetically:
            self._field_map.sort_alphabetically()
        else:
            self._field_map.sort_definition_order(self._cls, self.meta.computed)

    def create_pydantic_model(self) -> Type[PydanticModel]:
        """
        Process every collected field and return the resulting pydantic model.

        A model already stored in ``_MODEL_INDEX`` under the same hash is returned
        as-is when its name also matches; otherwise a new model is created and
        stored under that hash.
        """
        for field_name, field in self._field_map.items():
            self._process_field(field_name, field)

        self._name, self._title = self.get_name()

        # Same hash means the same content, but the name could still be different.
        hashed_model = _MODEL_INDEX.get(self._hash)
        if hashed_model is not None and hashed_model.__name__ == self._name:
            return hashed_model

        self._pconfig = self._initialize_pconfig()
        self._properties["model_config"] = self._pconfig
        model = create_model(
            self._name,
            __base__=PydanticModel,
            __module__=self._module,
            __validators__=self._validators,
            **self._properties,
        )
        # Copy the Model docstring over
        model.__doc__ = _cleandoc(self._cls)
        # Store the base class
        model.model_config["orig_model"] = self._cls  # type: ignore
        # Store model reference so we can de-dup it later on if needed.
        _MODEL_INDEX[self._hash] = model
        return model

    def _process_field(
        self,
        field_name: str,
        field: Union[Field, ComputedFieldDescription],
    ) -> None:
        """
        Convert one entry of the field map into an entry of ``self._properties``.

        Nothing is stored when the field resolves to no type (for example a relation
        whose submodel was suppressed, or a computed field without a return annotation).
        """
        json_schema_extra: Dict[str, Any] = {}
        fconfig: Dict[str, Any] = {
            "json_schema_extra": json_schema_extra,
        }
        field_property: Optional[Any] = None
        is_to_one_relation: bool = False
        if isinstance(field, Field):
            field_property, is_to_one_relation = self._process_normal_field(
                field_name, field, json_schema_extra, fconfig
            )
            if field_property:
                fconfig["title"] = field_name.replace("_", " ").title()
                description = _br_it(field.docstring or field.description or "")
                if description:
                    fconfig["description"] = description
                if field_name in self._optional or (
                    field.default is not None and not callable(field.default)
                ):
                    self._properties[field_name] = (
                        field_property,
                        PydanticField(default=field.default, **fconfig),
                    )
                else:
                    if (json_schema_extra.get("nullable") and not is_to_one_relation) or (
                        self._exclude_read_only and json_schema_extra.get("readOnly")
                    ):
                        # see: https://docs.pydantic.dev/latest/migration/#required-optional-and-nullable-fields
                        fconfig["default"] = None
                    self._properties[field_name] = (field_property, PydanticField(**fconfig))
        elif isinstance(field, ComputedFieldDescription):
            # The computed-field object is stored directly; its description is
            # already attached by _process_computed_field().
            field_property = self._process_computed_field(field)
            if field_property:
                self._properties[field_name] = field_property

    def _process_normal_field(
        self,
        field_name: str,
        field: Field,
        json_schema_extra: Dict[str, Any],
        fconfig: Dict[str, Any],
    ) -> Tuple[Optional[Any], bool]:
        """
        Dispatch a Tortoise field to the matching handler.

        :return: ``(type_or_None, is_to_one_relation)``.
        """
        if isinstance(
            field, (ForeignKeyFieldInstance, OneToOneFieldInstance, BackwardOneToOneRelation)
        ):
            return self._process_single_field_relation(field_name, field, json_schema_extra), True
        elif isinstance(field, (BackwardFKRelation, ManyToManyFieldInstance)):
            return self._process_many_field_relation(field_name, field), False
        # NOTE(review): this compares the field's ``field_type`` attribute with the
        # JSONField class itself - confirm that is how JSON fields expose it.
        elif field.field_type is JSONField:
            return Any, False
        return self._process_data_field(field_name, field, json_schema_extra, fconfig), False

    def _process_single_field_relation(
        self,
        field_name: str,
        field: Union[ForeignKeyFieldInstance, OneToOneFieldInstance, BackwardOneToOneRelation],
        json_schema_extra: Dict[str, Any],
    ) -> Optional[Type[PydanticModel]]:
        """
        Resolve a to-one relation to its submodel type.

        The type is wrapped in ``Optional`` when the field is nullable or has a
        default; ``None`` is returned when no submodel could be built.
        """
        python_type = getattr(field, "related_model", field.field_type)
        model: Optional[Type[PydanticModel]] = self._get_submodel(python_type, field_name)
        if model:
            self._relational_fields_index.append((field_name, model.__name__))
            if field.null:
                json_schema_extra["nullable"] = True
            if field.null or field.default is not None:
                model = Optional[model]  # type: ignore
            return model
        return None

    def _process_many_field_relation(
        self,
        field_name: str,
        field: Union[BackwardFKRelation, ManyToManyFieldInstance],
    ) -> Optional[Type[List[Type[PydanticModel]]]]:
        """Resolve a to-many relation to ``List[submodel]``, or ``None`` when no submodel could be built."""
        python_type = field.related_model
        model = self._get_submodel(python_type, field_name)
        if model:
            self._relational_fields_index.append((field_name, model.__name__))
            return List[model]  # type: ignore
        return None

    def _process_data_field(
        self,
        field_name: str,
        field: Field,
        json_schema_extra: Dict[str, Any],
        fconfig: Dict[str, Any],
    ) -> Optional[Any]:
        """
        Resolve a plain data field to its python type.

        The field's constraints are merged into ``fconfig`` (``readOnly`` is moved
        to ``json_schema_extra`` instead). An annotation declared on the model takes
        precedence over the derived type. Returns ``None`` for read-only fields when
        read-only fields are being excluded.
        """
        annotation = self._annotations.get(field_name, None)
        constraints = copy(field.constraints)
        if "readOnly" in constraints:
            json_schema_extra["readOnly"] = constraints["readOnly"]
            del constraints["readOnly"]
        fconfig.update(constraints)
        python_type: Union[Type[Enum], Type[IntEnum], Type]
        if isinstance(field, (IntEnumFieldInstance, CharEnumFieldInstance)):
            python_type = field.enum_type
        else:
            python_type = getattr(field, "related_model", field.field_type)
        ptype = python_type
        if field.null:
            json_schema_extra["nullable"] = True
        if not field.pk and (
            field_name in self._optional or field.default is not None or field.null
        ):
            ptype = Optional[ptype]
        if not (self._exclude_read_only and json_schema_extra.get("readOnly") is True):
            return annotation or ptype
        return None

    def _process_computed_field(
        self,
        field: ComputedFieldDescription,
    ) -> Optional[Any]:
        """
        Wrap the computed function with pydantic's ``computed_field``.

        Returns ``None`` when the function has no return annotation.
        """
        func = field.function
        annotation = get_annotations(self._cls, func).get("return", None)
        comment = _cleandoc(func)
        if annotation is not None:
            c_f = computed_field(return_type=annotation, description=comment)
            ret = c_f(func)
            return ret
        return None

    def _get_submodel(
        self, _model: Optional["Type[Model]"], field_name: str
    ) -> Optional[Type[PydanticModel]]:
        """
        Get Pydantic model for the submodel.

        Only the ``exclude`` / ``include`` / ``computed`` entries prefixed with
        ``"<field_name>."`` are carried over, with that prefix removed. When no
        submodel results, ``field_name`` is appended to ``self.meta.exclude``.
        """
        if _model:
            new_stack = self._stack + ((self._cls, field_name, self.meta.max_recursion),)

            # Get pydantic schema for the submodel
            prefix = field_name + "."
            prefix_len = len(prefix)

            def get_fields_to_carry_on(field_tuple: Tuple[str, ...]) -> Tuple[str, ...]:
                return tuple(str(v[prefix_len:]) for v in field_tuple if v.startswith(prefix))

            pmodel = _pydantic_recursion_protector(
                _model,
                exclude=get_fields_to_carry_on(self.meta.exclude),
                include=get_fields_to_carry_on(self.meta.include),
                computed=get_fields_to_carry_on(self.meta.computed),
                stack=new_stack,
                allow_cycles=self.meta.allow_cycles,
                sort_alphabetically=self.meta.sort_alphabetically,
            )
        else:
            pmodel = None

        # If the result is None it has been excluded and we need to exclude the field
        if pmodel is None:
            self.meta.exclude += (field_name,)

        return pmodel
def pydantic_model_creator(
    cls: "Type[Model]",
    *,
    name=None,
    exclude: Optional[Tuple[str, ...]] = None,
    include: Optional[Tuple[str, ...]] = None,
    computed: Optional[Tuple[str, ...]] = None,
    optional: Optional[Tuple[str, ...]] = None,
    allow_cycles: Optional[bool] = None,
    sort_alphabetically: Optional[bool] = None,
    exclude_readonly: bool = False,
    meta_override: Optional[Type] = None,
    model_config: Optional[ConfigDict] = None,
    validators: Optional[Dict[str, Any]] = None,
    module: str = __name__,
) -> Type[PydanticModel]:
    """
    Function to build `Pydantic Model <https://docs.pydantic.dev/latest/concepts/models/>`__ off Tortoise Model.

    :param cls: The Tortoise Model
    :param name: Specify a custom name explicitly, instead of a generated name.
    :param exclude: Extra fields to exclude from the provided model.
    :param include: Extra fields to include from the provided model.
    :param computed: Extra computed fields to include from the provided model.
    :param optional: Extra optional fields for the provided model.
    :param allow_cycles: Do we allow any cycles in the generated model?
        This is only useful for recursive/self-referential models.

        A value of ``False`` (the default) will prevent any and all backtracking.
    :param sort_alphabetically: Sort the parameters alphabetically instead of Field-definition order.

        The default order would be:

            * Field definition order +
            * order of reverse relations (as discovered) +
            * order of computed functions (as provided).
    :param exclude_readonly: Build a subset model that excludes any readonly fields
    :param meta_override: A PydanticMeta class to override model's values.
    :param model_config: A custom config to use as pydantic config.
    :param validators: A dictionary of methods that validate fields.
    :param module: The name of the module that the model belongs to.

    Note: ``model_config`` only customises the pydantic configuration of the created
        model. The model's fields are always generated automatically from the
        include/exclude/computed parameters, combined with the model's
        ``PydanticMeta`` when it defines one.
    """
    pmc = PydanticModelCreator(
        cls=cls,
        name=name,
        exclude=exclude,
        include=include,
        computed=computed,
        optional=optional,
        allow_cycles=allow_cycles,
        sort_alphabetically=sort_alphabetically,
        exclude_readonly=exclude_readonly,
        meta_override=meta_override,
        model_config=model_config,
        validators=validators,
        module=module,
    )
    return pmc.create_pydantic_model()