Source code for tortoise.fields.relational

from typing import (
    TYPE_CHECKING,
    Any,
    AsyncGenerator,
    Generator,
    Generic,
    Iterator,
    List,
    Optional,
    Tuple,
    Type,
    TypeVar,
    Union,
    overload,
)

from pypika import Table
from typing_extensions import Literal

from tortoise.exceptions import ConfigurationError, NoValuesFetched, OperationalError
from tortoise.fields.base import CASCADE, SET_NULL, Field, OnDelete

if TYPE_CHECKING:  # pragma: nocoverage
    from tortoise.backends.base.client import BaseDBAsyncClient
    from tortoise.models import Model
    from tortoise.queryset import Q, QuerySet

# Type variable standing for a concrete model class. ``bound="Model"`` is a
# forward reference: ``Model`` is only imported under ``TYPE_CHECKING`` above.
MODEL = TypeVar("MODEL", bound="Model")


class _NoneAwaitable:
    """
    Awaitable placeholder that is always falsy and resolves to ``None``.

    Awaiting it performs a single bare ``yield`` and then finishes, so the
    ``await`` expression evaluates to ``None``.
    """

    __slots__ = ()

    def __bool__(self) -> bool:
        # Unconditionally falsy, so truthiness checks treat it like ``None``.
        return False

    def __await__(self) -> Generator[None, None, None]:
        # One bare yield (``None``); the generator's return value is ``None``.
        yield


# Shared module-level instance; the class holds no per-instance state.
NoneAwaitable = _NoneAwaitable()


class ReverseRelation(Generic[MODEL]):
    """
    Relation container for :func:`.ForeignKeyField`.

    Represents the ``remote_model`` rows whose ``relation_field`` equals
    ``getattr(instance, from_field)``.

    * ``await relation`` runs the query and returns the list of rows. It does
      not cache the result on the container.
    * ``async for item in relation`` fetches once (if not already fetched),
      caches the rows in ``related_objects``, then iterates them.
    * Synchronous access (``in``, ``len``, iteration, indexing, ``bool``)
      only works after values were fetched. Otherwise it raises
      :class:`NoValuesFetched`.
    """

    def __init__(
        self,
        remote_model: Type[MODEL],
        relation_field: str,
        instance: "Model",
        from_field: str,
    ) -> None:
        """
        :param remote_model: Model class on the other side of the relation.
        :param relation_field: Field name on ``remote_model`` used to filter.
        :param instance: Model instance that owns this container.
        :param from_field: Attribute on ``instance`` whose value is matched.
        """
        self.remote_model = remote_model
        self.relation_field = relation_field
        self.instance = instance
        self.from_field = from_field
        self._fetched = False
        self._custom_query = False
        self.related_objects: List[MODEL] = []

    @property
    def _query(self) -> "QuerySet[MODEL]":
        """
        Base QuerySet for the related rows.

        :raises OperationalError: If the owning instance was never saved.
        """
        if not self.instance._saved_in_db:
            raise OperationalError(
                "This objects hasn't been instanced, call .save() before calling related queries"
            )
        return self.remote_model.filter(
            **{self.relation_field: getattr(self.instance, self.from_field)}
        )

    def __contains__(self, item: Any) -> bool:
        self._raise_if_not_fetched()
        return item in self.related_objects

    def __iter__(self) -> "Iterator[MODEL]":
        self._raise_if_not_fetched()
        return self.related_objects.__iter__()

    def __len__(self) -> int:
        self._raise_if_not_fetched()
        return len(self.related_objects)

    def __bool__(self) -> bool:
        self._raise_if_not_fetched()
        return bool(self.related_objects)

    def __getitem__(self, item: int) -> MODEL:
        self._raise_if_not_fetched()
        return self.related_objects[item]

    def __await__(self) -> Generator[Any, None, List[MODEL]]:
        # Delegates to the QuerySet. The result is returned, not cached.
        return self._query.__await__()

    # Yields MODEL instances and accepts nothing via ``asend`` (send type None).
    async def __aiter__(self) -> AsyncGenerator[MODEL, None]:
        if not self._fetched:
            self._set_result_for_query(await self)
        for val in self.related_objects:
            yield val

    def filter(self, *args: "Q", **kwargs: Any) -> "QuerySet[MODEL]":
        """
        Returns a QuerySet with related elements filtered by args/kwargs.
        """
        return self._query.filter(*args, **kwargs)

    def all(self) -> "QuerySet[MODEL]":
        """
        Returns a QuerySet with all related elements.
        """
        return self._query

    def order_by(self, *orderings: str) -> "QuerySet[MODEL]":
        """
        Returns a QuerySet related elements in order.
        """
        return self._query.order_by(*orderings)

    def limit(self, limit: int) -> "QuerySet[MODEL]":
        """
        Returns a QuerySet with at most «limit» related elements.
        """
        return self._query.limit(limit)

    def offset(self, offset: int) -> "QuerySet[MODEL]":
        """
        Returns a QuerySet with all related elements offset by «offset».
        """
        return self._query.offset(offset)

    def _set_result_for_query(self, sequence: List[MODEL], attr: Optional[str] = None) -> None:
        """
        Store ``sequence`` as the fetched related objects.

        :param attr: When truthy, ``sequence`` is also set as that attribute
            on the owning instance.
        """
        self._fetched = True
        self.related_objects = sequence
        if attr:
            setattr(self.instance, attr, sequence)

    def _raise_if_not_fetched(self) -> None:
        """Raise :class:`NoValuesFetched` unless values were fetched first."""
        if not self._fetched:
            raise NoValuesFetched(
                "No values were fetched for this relation, first use .fetch_related()"
            )
class ManyToManyRelation(ReverseRelation[MODEL]):
    """
    Many-to-many relation container for :func:`.ManyToManyField`.

    Reads go through the inherited :class:`ReverseRelation` machinery, filtered
    on ``m2m_field.related_name`` against the owner's ``pk``. Writes
    (``add``/``remove``/``clear``) operate directly on the through table
    named by ``m2m_field.through``.
    """

    def __init__(self, instance: "Model", m2m_field: "ManyToManyFieldInstance[MODEL]") -> None:
        super().__init__(m2m_field.related_model, m2m_field.related_name, instance, "pk")
        self.field = m2m_field
        self.instance = instance

    async def add(self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" = None) -> None:
        """
        Link each of ``instances`` to the owning instance.

        Links that already exist in the through table are skipped, and
        duplicate pks among ``instances`` produce a single row. Calling with no
        instances is a no-op.

        :param using_db: Client to run the queries on. Defaults to the remote
            model's configured database.
        :raises OperationalError: If the owner or any instance to add is unsaved.
        """
        if not instances:
            return
        owner = self.instance
        if not owner._saved_in_db:
            raise OperationalError(f"You should first call .save() on {owner}")
        db = using_db or self.remote_model._meta.db

        format_owner_pk = type(owner)._meta.pk.to_db_value
        # All instances are formatted with the pk field of the first one's model.
        format_related_pk = type(instances[0])._meta.pk.to_db_value
        owner_pk = format_owner_pk(owner.pk, owner)

        related_pks: list = []
        for candidate in instances:
            if not candidate._saved_in_db:
                raise OperationalError(f"You should first call .save() on {candidate}")
            related_pks.append(format_related_pk(candidate.pk, candidate))

        table = Table(self.field.through)
        backward_key = self.field.backward_key
        forward_key = self.field.forward_key
        backward_column = table[backward_key]
        forward_column = table[forward_key]

        # Find which of the requested links already exist.
        lookup = db.query_class.from_(table).where(backward_column == owner_pk).select(forward_key)
        if len(related_pks) == 1:
            criterion = forward_column == related_pks[0]
        else:
            criterion = forward_column.isin(related_pks)
        lookup = lookup.where(criterion)
        _, existing_rows = await db.execute_query(*lookup.get_parameterized_sql())
        existing_pks = {format_related_pk(row[forward_key], owner) for row in existing_rows}

        missing_pks = set(related_pks) - existing_pks
        if not missing_pks:
            return
        insert = db.query_class.into(table).columns(forward_column, backward_column)
        for related_pk in missing_pks:
            insert = insert.insert(related_pk, owner_pk)
        await db.execute_query(*insert.get_parameterized_sql())

    async def clear(self, using_db: "Optional[BaseDBAsyncClient]" = None) -> None:
        """
        Delete every through-table row that links to the owning instance.
        """
        await self._remove_or_clear(using_db=using_db)

    async def remove(
        self, *instances: MODEL, using_db: "Optional[BaseDBAsyncClient]" = None
    ) -> None:
        """
        Delete the links between the owning instance and each of ``instances``.

        :raises OperationalError: If called without any instances.
        """
        if not instances:
            raise OperationalError("remove() called on no instances")
        await self._remove_or_clear(instances, using_db)

    async def _remove_or_clear(
        self,
        instances: Optional[Tuple[MODEL, ...]] = None,
        using_db: "Optional[BaseDBAsyncClient]" = None,
    ) -> None:
        """
        Delete through-table rows for the owner, restricted to ``instances`` if any.
        """
        db = using_db or self.remote_model._meta.db
        table = Table(self.field.through)
        owner = self.instance
        format_owner_pk = type(owner)._meta.pk.to_db_value
        condition = table[self.field.backward_key] == format_owner_pk(owner.pk, owner)
        if instances:
            format_related_pk = type(instances[0])._meta.pk.to_db_value
            if len(instances) == 1:
                only = instances[0]
                condition &= table[self.field.forward_key] == format_related_pk(only.pk, only)
            else:
                condition &= table[self.field.forward_key].isin(
                    [format_related_pk(item.pk, item) for item in instances]
                )
        delete = db.query_class.from_(table).where(condition).delete()
        await db.execute_query(*delete.get_parameterized_sql())
class RelationalField(Field[MODEL]):
    """
    Common base for fields that point at another model.
    """

    # NOTE(review): presumably tells the base Field machinery that this field
    # is not itself a DB column — confirm against tortoise.fields.base.
    has_db_field = False

    def __init__(
        self,
        related_model: "Type[MODEL]",
        to_field: Optional[str] = None,
        db_constraint: bool = True,
        **kwargs: Any,
    ) -> None:
        super().__init__(**kwargs)
        self.related_model: "Type[MODEL]" = related_model
        self.to_field: str = to_field  # type: ignore
        # Left unset here; expected to be filled in later (TODO confirm where).
        self.to_field_instance: Field = None  # type: ignore
        self.db_constraint = db_constraint

    if TYPE_CHECKING:

        @overload
        def __get__(self, instance: None, owner: Type["Model"]) -> "RelationalField[MODEL]": ...

        @overload
        def __get__(self, instance: "Model", owner: Type["Model"]) -> MODEL: ...

        def __get__(
            self, instance: Optional["Model"], owner: Type["Model"]
        ) -> "RelationalField[MODEL] | MODEL": ...

        def __set__(self, instance: "Model", value: MODEL) -> None: ...

    def describe(self, serializable: bool) -> dict:
        """Base description plus ``db_constraint``, without ``db_column``."""
        description = super().describe(serializable)
        description["db_constraint"] = self.db_constraint
        description.pop("db_column")
        return description

    @classmethod
    def validate_model_name(cls, model_name: str) -> None:
        """
        Require ``model_name`` to contain exactly one dot (``"app.Model"``).

        :raises ConfigurationError: Otherwise.
        """
        if model_name.count(".") != 1:
            field_type = cls.__name__.replace("Instance", "")
            raise ConfigurationError(f'{field_type} accepts model name in format "app.Model"')


class ForeignKeyFieldInstance(RelationalField[MODEL]):
    """
    Field object produced by :func:`ForeignKeyField`.
    """

    def __init__(
        self,
        model_name: str,
        related_name: Union[Optional[str], Literal[False]] = None,
        on_delete: OnDelete = CASCADE,
        **kwargs: Any,
    ) -> None:
        # The related model class is unknown yet; only its name is stored.
        super().__init__(None, **kwargs)  # type: ignore
        self.validate_model_name(model_name)
        self.model_name = model_name
        self.related_name = related_name
        if on_delete not in set(OnDelete):
            raise ConfigurationError(
                "on_delete can only be CASCADE, RESTRICT, SET_NULL, SET_DEFAULT or NO_ACTION"
            )
        nullable = kwargs.get("null")
        if on_delete == SET_NULL and not nullable:
            raise ConfigurationError("If on_delete is SET_NULL, then field must have null=True set")
        self.on_delete = on_delete

    def describe(self, serializable: bool) -> dict:
        """Base description plus ``raw_field`` and the stringified ``on_delete``."""
        description = super().describe(serializable)
        description.update(raw_field=self.source_field, on_delete=str(self.on_delete))
        return description


class BackwardFKRelation(RelationalField[MODEL]):
    """
    Reverse side of a foreign key, stored on the referenced model.
    """

    def __init__(
        self,
        field_type: "Type[MODEL]",
        relation_field: str,
        relation_source_field: str,
        null: bool,
        description: Optional[str],
        **kwargs: Any,
    ) -> None:
        super().__init__(field_type, null=null, **kwargs)
        self.relation_field: str = relation_field
        self.relation_source_field: str = relation_source_field
        self.description: Optional[str] = description


class OneToOneFieldInstance(ForeignKeyFieldInstance[MODEL]):
    """
    Field object produced by :func:`OneToOneField`: a foreign key forced to ``unique=True``.
    """

    def __init__(
        self,
        model_name: str,
        related_name: Union[Optional[str], Literal[False]] = None,
        on_delete: OnDelete = CASCADE,
        **kwargs: Any,
    ) -> None:
        self.validate_model_name(model_name)
        super().__init__(model_name, related_name, on_delete, unique=True, **kwargs)


class BackwardOneToOneRelation(BackwardFKRelation[MODEL]):
    """Reverse side of a one-to-one relation."""

    pass


class ManyToManyFieldInstance(RelationalField[MODEL]):
    """
    Field object produced by :func:`ManyToManyField`.
    """

    field_type = ManyToManyRelation

    def __init__(
        self,
        model_name: str,
        through: Optional[str] = None,
        forward_key: Optional[str] = None,
        backward_key: str = "",
        related_name: str = "",
        on_delete: OnDelete = CASCADE,
        field_type: "Type[MODEL]" = None,  # type: ignore
        create_unique_index: bool = True,
        **kwargs: Any,
    ) -> None:
        # TODO: rename through to through_table
        # TODO: add through to use a Model
        super().__init__(field_type, **kwargs)
        self.validate_model_name(model_name)
        self.model_name: str = model_name
        self.related_name: str = related_name
        # Validation above guarantees exactly one dot, so partition yields the model part.
        model_part = model_name.partition(".")[2]
        self.forward_key: str = forward_key or f"{model_part.lower()}_id"
        self.backward_key: str = backward_key
        self.through: str = through  # type: ignore
        self._generated: bool = False
        self.on_delete = on_delete
        self.create_unique_index = create_unique_index

    def describe(self, serializable: bool) -> dict:
        """Base description plus the many-to-many configuration keys."""
        description = super().describe(serializable)
        description.update(
            model_name=self.model_name,
            related_name=self.related_name,
            forward_key=self.forward_key,
            backward_key=self.backward_key,
            through=self.through,
            on_delete=str(self.on_delete),
            _generated=self._generated,
        )
        return description


@overload
def OneToOneField(
    model_name: str,
    related_name: Union[Optional[str], Literal[False]] = None,
    on_delete: OnDelete = CASCADE,
    db_constraint: bool = True,
    *,
    null: Literal[True],
    **kwargs: Any,
) -> "OneToOneNullableRelation[MODEL]": ...


@overload
def OneToOneField(
    model_name: str,
    related_name: Union[Optional[str], Literal[False]] = None,
    on_delete: OnDelete = CASCADE,
    db_constraint: bool = True,
    null: Literal[False] = False,
    **kwargs: Any,
) -> "OneToOneRelation[MODEL]": ...
def OneToOneField(
    model_name: str,
    related_name: Union[Optional[str], Literal[False]] = None,
    on_delete: OnDelete = CASCADE,
    db_constraint: bool = True,
    null: bool = False,
    **kwargs: Any,
) -> "OneToOneRelation[MODEL] | OneToOneNullableRelation[MODEL]":
    """
    OneToOne relation field.

    Declares a one-to-one relation to another model: a foreign key whose
    values must be unique. See :ref:`one_to_one` for usage information.

    Required:

    ``model_name``:
        The related model, written as :samp:`'{app}.{model}'`.

    Optional:

    ``related_name``:
        Attribute name on the related model used to reach back to this one.
    ``on_delete``:
        What happens when the related row is deleted:

            ``field.CASCADE``:
                Delete this model as well.
            ``field.RESTRICT``:
                Refuse the delete while a foreign key still points at the row.
            ``field.SET_NULL``:
                Set this field to NULL. Only allowed together with ``null=True``.
            ``field.SET_DEFAULT``:
                Set this field to its ``default``. Only allowed if the field has a ``default``.
            ``field.NO_ACTION``:
                Do nothing.
    ``to_field``:
        Attribute on the related model that the key references. Defaults to its pk.
    ``db_constraint``:
        Whether to create a database constraint for this foreign key. Defaults
        to True; turning it off can seriously harm data integrity.
    """
    field = OneToOneFieldInstance(
        model_name=model_name,
        related_name=related_name,
        on_delete=on_delete,
        db_constraint=db_constraint,
        null=null,
        **kwargs,
    )
    return field
# Typing-only overload: passing ``null=True`` (keyword-only here) types the
# field as a nullable foreign-key relation.
@overload
def ForeignKeyField(
    model_name: str,
    related_name: Union[Optional[str], Literal[False]] = None,
    on_delete: OnDelete = CASCADE,
    db_constraint: bool = True,
    *,
    null: Literal[True],
    **kwargs: Any,
) -> "ForeignKeyNullableRelation[MODEL]": ...


# Typing-only overload: ``null=False`` (the default) types the field as a
# non-nullable foreign-key relation.
@overload
def ForeignKeyField(
    model_name: str,
    related_name: Union[Optional[str], Literal[False]] = None,
    on_delete: OnDelete = CASCADE,
    db_constraint: bool = True,
    null: Literal[False] = False,
    **kwargs: Any,
) -> "ForeignKeyRelation[MODEL]": ...
def ForeignKeyField(
    model_name: str,
    related_name: Union[Optional[str], Literal[False]] = None,
    on_delete: OnDelete = CASCADE,
    db_constraint: bool = True,
    null: bool = False,
    **kwargs: Any,
) -> "ForeignKeyRelation[MODEL] | ForeignKeyNullableRelation[MODEL]":
    """
    ForeignKey relation field.

    Declares a foreign key relation to another model.
    See :ref:`foreign_key` for usage information.

    Required:

    ``model_name``:
        The related model, written as :samp:`'{app}.{model}'`.

    Optional:

    ``related_name``:
        Attribute name on the related model used to reach back to this one.
    ``on_delete``:
        What happens when the related row is deleted:

            ``field.CASCADE``:
                Delete this model as well.
            ``field.RESTRICT``:
                Refuse the delete while a foreign key still points at the row.
            ``field.SET_NULL``:
                Set this field to NULL. Only allowed together with ``null=True``.
            ``field.SET_DEFAULT``:
                Set this field to its ``default``. Only allowed if the field has a ``default``.
            ``field.NO_ACTION``:
                Do nothing.
    ``to_field``:
        Attribute on the related model that the key references. Defaults to its pk.
    ``db_constraint``:
        Whether to create a database constraint for this foreign key. Defaults
        to True; turning it off can seriously harm data integrity.
    """
    field = ForeignKeyFieldInstance(
        model_name=model_name,
        related_name=related_name,
        on_delete=on_delete,
        db_constraint=db_constraint,
        null=null,
        **kwargs,
    )
    return field
def ManyToManyField(
    model_name: str,
    through: Optional[str] = None,
    forward_key: Optional[str] = None,
    backward_key: str = "",
    related_name: str = "",
    on_delete: OnDelete = CASCADE,
    db_constraint: bool = True,
    create_unique_index: bool = True,
    **kwargs: Any,
) -> "ManyToManyRelation[Any]":
    """
    ManyToMany relation field.

    Declares a many-to-many relation between this model and another one.
    See :ref:`many_to_many` for usage information.

    Required:

    ``model_name``:
        The related model, written as :samp:`'{app}.{model}'`.

    Optional:

    ``through``:
        Name of the DB table that links the two models. The default is normally safe.
    ``forward_key``:
        Forward lookup key on the through table. The default is normally safe.
    ``backward_key``:
        Backward lookup key on the through table. The default is normally safe.
    ``related_name``:
        Attribute name on the related model used to reach back to this one.
    ``db_constraint``:
        Whether to create a database constraint for this relation. Defaults to
        True; turning it off can seriously harm data integrity.
    ``on_delete``:
        What happens when the related row is deleted:

            ``field.CASCADE``:
                Delete this model as well.
            ``field.RESTRICT``:
                Refuse the delete while a foreign key still points at the row.
            ``field.SET_NULL``:
                Set this field to NULL. Only allowed together with ``null=True``.
            ``field.SET_DEFAULT``:
                Set this field to its ``default``. Only allowed if the field has a ``default``.
            ``field.NO_ACTION``:
                Do nothing.
    ``create_unique_index``:
        Whether to create a unique index on the through table, which also
        speeds up select queries. Defaults to True. Set it to False to allow
        repeated records.
    """
    field = ManyToManyFieldInstance(  # type: ignore
        model_name=model_name,
        through=through,
        forward_key=forward_key,
        backward_key=backward_key,
        related_name=related_name,
        on_delete=on_delete,
        db_constraint=db_constraint,
        create_unique_index=create_unique_index,
        **kwargs,
    )
    return field
# Type aliases for annotating relation attributes on models. They resolve to
# the field-instance classes (optionally wrapped in ``Optional``). Their
# ``RelationalField.__get__`` overloads (typing-only) map instance access to
# the related model type.
OneToOneNullableRelation = Optional[OneToOneFieldInstance[MODEL]]
"""
Type hint for the result of accessing the :func:`.OneToOneField` field in the model
when obtained model can be nullable.
"""

OneToOneRelation = OneToOneFieldInstance[MODEL]
"""
Type hint for the result of accessing the :func:`.OneToOneField` field in the model.
"""

ForeignKeyNullableRelation = Optional[ForeignKeyFieldInstance[MODEL]]
"""
Type hint for the result of accessing the :func:`.ForeignKeyField` field in the model
when obtained model can be nullable.
"""

ForeignKeyRelation = ForeignKeyFieldInstance[MODEL]
"""
Type hint for the result of accessing the :func:`.ForeignKeyField` field in the model.
"""