import types
from copy import copy
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Generator,
Generic,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
from pypika import JoinType, Order, Table
from pypika.analytics import Count
from pypika.functions import Cast
from pypika.queries import QueryBuilder
from pypika.terms import Case, Field, Term, ValueWrapper
from typing_extensions import Literal, Protocol
from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities
from tortoise.exceptions import (
DoesNotExist,
FieldError,
IntegrityError,
MultipleObjectsReturned,
ParamsError,
)
from tortoise.expressions import Expression, F, Q, RawSQL
from tortoise.fields.relational import (
ForeignKeyFieldInstance,
OneToOneFieldInstance,
RelationalField,
)
from tortoise.filters import FilterInfoDict
from tortoise.functions import Function
from tortoise.query_utils import Prefetch, QueryModifier, _get_joins_for_related_field
from tortoise.router import router
from tortoise.utils import chunk
if TYPE_CHECKING:  # pragma: nocoverage
    # Only needed for annotations; presumably kept out of runtime imports to
    # avoid an import cycle with ``tortoise.models``.
    from tortoise.models import Model

# Shared, empty placeholder query -- it must never be edited. Every query object
# starts out pointing at it and assigns its own builder in ``_make_query``.
QUERY: QueryBuilder = QueryBuilder()

# Concrete model class a query operates on.
MODEL = TypeVar("MODEL", bound="Model")
# Covariant result type produced by awaiting a single-result query.
T_co = TypeVar("T_co", covariant=True)
# Boolean flag type; presumably parametrizes the Values*Query classes with
# ``Literal[True]``/``Literal[False]`` (single vs. multiple rows).
SINGLE = TypeVar("SINGLE", bound=bool)
class QuerySetSingle(Protocol[T_co]):
    """
    Structural type for a query that resolves to one value instead of a list.

    Awaiting an object of this type yields a single ``T_co`` -- a model instance,
    or ``None`` for the ``Optional`` variants (see ``QuerySet.first`` and
    ``QuerySet.get_or_none``). The chaining methods declared here keep that
    single-result typing, while ``values``/``values_list`` switch to the
    single-row flavour of the respective values query.
    """

    # pylint: disable=W0104

    def __await__(self) -> Generator[Any, None, T_co]:
        ...  # pragma: nocoverage

    def prefetch_related(self, *args: Union[str, Prefetch]) -> "QuerySetSingle[T_co]":
        ...  # pragma: nocoverage

    def select_related(self, *args: str) -> "QuerySetSingle[T_co]":
        ...  # pragma: nocoverage

    def annotate(self, **kwargs: Function) -> "QuerySetSingle[T_co]":
        ...  # pragma: nocoverage

    def only(self, *fields_for_select: str) -> "QuerySetSingle[T_co]":
        ...  # pragma: nocoverage

    def values_list(self, *fields_: str, flat: bool = False) -> "ValuesListQuery[Literal[True]]":
        ...  # pragma: nocoverage

    def values(self, *args: str, **kwargs: str) -> "ValuesQuery[Literal[True]]":
        ...  # pragma: nocoverage
class AwaitableQuery(Generic[MODEL]):
    """
    Base class for the awaitable query objects in this module.

    It holds the pypika ``QueryBuilder`` being assembled (``self.query``) and
    provides the shared machinery for connection selection, filter / annotation /
    ordering resolution and join bookkeeping. Subclasses implement
    ``_make_query`` (build ``self.query``) and ``_execute`` (run it).
    """

    __slots__ = (
        "_joined_tables",
        "query",
        "model",
        "_db",
        "capabilities",
        "_annotations",
    )

    def __init__(self, model: Type[MODEL]) -> None:
        # Tables already joined into ``self.query``; checked to avoid duplicate joins.
        self._joined_tables: List[Table] = []
        self.model: "Type[MODEL]" = model
        # Starts as the shared empty placeholder; ``_make_query`` replaces it.
        self.query: QueryBuilder = QUERY
        # ``None`` means "not chosen yet"; resolved lazily via ``_choose_db``.
        self._db: BaseDBAsyncClient = None  # type: ignore
        self.capabilities: Capabilities = model._meta.db.capabilities
        self._annotations: Dict[str, Expression] = {}

    def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient:
        """
        Return the connection that will be used if this query is executed now.

        An explicitly bound connection (``self._db``) wins; otherwise the router
        is asked for a read or write connection, falling back to the model's
        default connection when the router returns nothing.

        :param for_write: Ask the router for a write connection instead of a read one.
        :return: BaseDBAsyncClient:
        """
        if self._db:
            return self._db
        if for_write:
            db = router.db_for_write(self.model)
        else:
            db = router.db_for_read(self.model)
        return db or self.model._meta.db

    def resolve_filters(
        self,
        model: "Type[Model]",
        q_objects: List[Q],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
    ) -> None:
        """
        Builds the common filters for a QuerySet.

        Resolves annotations, combines all ``Q`` objects with AND, adds the joins
        they need (LEFT OUTER, each table at most once) and installs the resulting
        WHERE / HAVING criteria on ``self.query``. When an aggregate annotation is
        combined with joins, HAVING criteria or ordering, a GROUP BY over every
        base-table column is added.

        :param model: The Model this queryset is based on.
        :param q_objects: The Q expressions to apply.
        :param annotations: Extra annotations to add.
        :param custom_filters: Pre-resolved filters to be passed through.
        """
        has_aggregate = self._resolve_annotate(annotations)

        modifier = QueryModifier()
        for node in q_objects:
            # Q nodes need the annotation/custom-filter context to resolve.
            node._annotations = annotations
            node._custom_filters = custom_filters
            modifier &= node.resolve(model, model._meta.basetable)

        where_criterion, joins, having_criterion = modifier.get_query_modifiers()
        for join in joins:
            if join[0] not in self._joined_tables:
                self.query = self.query.join(join[0], how=JoinType.left_outer).on(join[1])
                self._joined_tables.append(join[0])

        self.query._wheres = where_criterion
        self.query._havings = having_criterion

        if has_aggregate and (self._joined_tables or having_criterion or self.query._orderbys):
            self.query = self.query.groupby(
                *[self.model._meta.basetable[field] for field in self.model._meta.db_fields]
            )

    def _join_table_by_field(
        self, table: Table, related_field_name: str, related_field: RelationalField
    ) -> Table:
        """
        LEFT OUTER JOIN every table needed to reach ``related_field`` from ``table``.

        Tables that are already joined are skipped.

        :return: The last joined table, i.e. the table of the related model.
        """
        joins = _get_joins_for_related_field(table, related_field, related_field_name)
        for join in joins:
            if join[0] not in self._joined_tables:
                self.query = self.query.join(join[0], how=JoinType.left_outer).on(join[1])
                self._joined_tables.append(join[0])
        return joins[-1][0]

    @staticmethod
    def _resolve_ordering_string(ordering: str) -> Tuple[str, Order]:
        """
        Split an ordering string such as ``"-name"`` into ``("name", Order.desc)``.

        A leading ``-`` selects descending order; anything else is ascending.
        An empty string yields ``("", Order.asc)`` rather than raising
        ``IndexError``, so callers can report it as an unknown field.
        """
        if ordering.startswith("-"):
            return ordering[1:], Order.desc
        return ordering, Order.asc

    def resolve_ordering(
        self,
        model: "Type[Model]",
        table: Table,
        orderings: Iterable[Tuple[str, str]],
        annotations: Dict[str, Any],
    ) -> None:
        """
        Applies standard ordering to QuerySet.

        Falls back to the model's ``Meta.ordering`` when no explicit ordering and
        no annotations are given. ``relation__field`` orderings join the related
        table and recurse into the related model.

        :param model: The Model this queryset is based on.
        :param table: ``pypika.Table`` to keep track of the virtual SQL table
            (to allow self referential joins)
        :param orderings: What columns/order to order by
        :param annotations: Annotations that may be ordered on

        :raises FieldError: If a field provided does not exist in model, or names a
            relation itself rather than a field of the related model.
        """
        # Do not apply default ordering for annotated queries to not mess them up
        if not orderings and self.model._meta.ordering and not annotations:
            orderings = self.model._meta.ordering

        for ordering in orderings:
            field_name = ordering[0]
            if field_name in model._meta.fetch_fields:
                raise FieldError(
                    "Ordering by relation is not possible. Order by nested field of related model"
                )

            related_field_name, __, forwarded = field_name.partition("__")
            if related_field_name in model._meta.fetch_fields:
                related_field = cast(RelationalField, model._meta.fields_map[related_field_name])
                related_table = self._join_table_by_field(table, related_field_name, related_field)
                self.resolve_ordering(
                    related_field.related_model,
                    related_table,
                    [(forwarded, ordering[1])],
                    {},
                )
            elif field_name in annotations:
                annotation = annotations[field_name]
                if isinstance(annotation, Term):
                    self.query = self.query.orderby(annotation, order=ordering[1])
                else:
                    annotation_info = annotation.resolve(self.model, table)
                    self.query = self.query.orderby(annotation_info["field"], order=ordering[1])
            else:
                field_object = model._meta.fields_map.get(field_name)
                if not field_object:
                    raise FieldError(f"Unknown field {field_name} for model {model.__name__}")
                field_name = field_object.source_field or field_name
                field = table[field_name]

                # Some fields need a dialect-specific cast before they can be sorted.
                func = field_object.get_for_dialect(
                    model._meta.db.capabilities.dialect, "function_cast"
                )
                if func:
                    field = func(field_object, field)
                self.query = self.query.orderby(field, order=ordering[1])

    def _resolve_annotate(self, extra_annotations: Dict[str, Any]) -> bool:
        """
        Resolve own and extra annotations, add their joins and select own ones.

        Only annotations stored on ``self._annotations`` are added to the SELECT
        list (aliased by their key); ``extra_annotations`` are only resolved.

        :return: ``True`` if any resolved annotation is an aggregate.
        """
        if not self._annotations and not extra_annotations:
            return False

        table = self.model._meta.basetable
        all_annotations = {**self._annotations, **extra_annotations}

        annotation_info = {}
        for key, annotation in all_annotations.items():
            if isinstance(annotation, Term):
                annotation_info[key] = {"joins": [], "field": annotation}
            else:
                annotation_info[key] = annotation.resolve(self.model, table)

        for key, info in annotation_info.items():
            for join in info["joins"]:
                self._join_table_by_field(*join)
            if key in self._annotations:
                self.query._select_other(info["field"].as_(key))

        return any(info["field"].is_aggregate for info in annotation_info.values())

    def sql(self, **kwargs) -> str:
        """
        Return the actual SQL.

        :param kwargs: Forwarded to pypika's ``get_sql``.
        """
        return self.as_query().get_sql(**kwargs)

    def as_query(self) -> QueryBuilder:
        """Return the actual query (builds it, choosing a connection if none is bound)."""
        if self._db is None:
            self._db = self._choose_db()  # type: ignore
        self._make_query()
        return self.query

    def _make_query(self) -> None:
        raise NotImplementedError()  # pragma: nocoverage

    async def _execute(self) -> Any:
        raise NotImplementedError()  # pragma: nocoverage
class QuerySet(AwaitableQuery[MODEL]):
    """
    Lazily built, chainable query returning model instances.

    Chaining methods return a modified clone (or ``self`` when the backend does
    not support the requested feature) instead of mutating ``self``. The SQL is
    only built (``_make_query``) and run (``_execute``) when the queryset is
    awaited or iterated with ``async for``.
    """

    # ``_db`` is deliberately not listed: it is already a slot of
    # ``AwaitableQuery`` and redeclaring a base-class slot shadows it.
    __slots__ = (
        "fields",
        "_prefetch_map",
        "_prefetch_queries",
        "_single",
        "_raise_does_not_exist",
        "_limit",
        "_offset",
        "_fields_for_select",
        "_filter_kwargs",
        "_orderings",
        "_q_objects",
        "_distinct",
        "_having",
        "_custom_filters",
        "_group_bys",
        "_select_for_update",
        "_select_for_update_nowait",
        "_select_for_update_skip_locked",
        "_select_for_update_of",
        "_select_related",
        "_select_related_idx",
        "_use_indexes",
        "_force_indexes",
    )

    def __init__(self, model: Type[MODEL]) -> None:
        super().__init__(model)
        self.fields: Set[str] = model._meta.db_fields
        self._prefetch_map: Dict[str, Set[Union[str, Prefetch]]] = {}
        self._prefetch_queries: Dict[str, List[Tuple[Optional[str], QuerySet]]] = {}
        # When set, awaiting yields one object (or None) instead of a list.
        self._single: bool = False
        # When set together with ``_single``, an empty result raises DoesNotExist.
        self._raise_does_not_exist: bool = False
        self._limit: Optional[int] = None
        self._offset: Optional[int] = None
        self._filter_kwargs: Dict[str, Any] = {}
        self._orderings: List[Tuple[str, Any]] = []
        self._q_objects: List[Q] = []
        self._distinct: bool = False
        self._having: Dict[str, Any] = {}
        self._custom_filters: Dict[str, FilterInfoDict] = {}
        self._fields_for_select: Tuple[str, ...] = ()
        self._group_bys: Tuple[str, ...] = ()
        self._select_for_update: bool = False
        self._select_for_update_nowait: bool = False
        self._select_for_update_skip_locked: bool = False
        self._select_for_update_of: Set[str] = set()
        self._select_related: Set[str] = set()
        self._select_related_idx: List[
            Tuple["Type[Model]", int, str, "Type[Model]", Iterable[Optional[str]]]
        ] = []  # format with: model,idx,model_name,parent_model
        self._force_indexes: Set[str] = set()
        self._use_indexes: Set[str] = set()

    def _clone(self) -> "QuerySet[MODEL]":
        """
        Return a copy of this queryset that can be modified independently.

        Mutable containers are shallow-copied so that methods which mutate them in
        place on the clone (e.g. ``force_index``/``use_index`` adding to their
        sets) do not leak those changes back into ``self``.
        """
        queryset = self.__class__.__new__(self.__class__)
        queryset.fields = self.fields
        queryset.model = self.model
        queryset.query = self.query
        queryset.capabilities = self.capabilities
        queryset._prefetch_map = copy(self._prefetch_map)
        queryset._prefetch_queries = copy(self._prefetch_queries)
        queryset._single = self._single
        queryset._raise_does_not_exist = self._raise_does_not_exist
        queryset._db = self._db
        queryset._limit = self._limit
        queryset._offset = self._offset
        queryset._fields_for_select = self._fields_for_select
        queryset._filter_kwargs = copy(self._filter_kwargs)
        queryset._orderings = copy(self._orderings)
        queryset._joined_tables = copy(self._joined_tables)
        queryset._q_objects = copy(self._q_objects)
        queryset._distinct = self._distinct
        queryset._annotations = copy(self._annotations)
        queryset._having = copy(self._having)
        queryset._custom_filters = copy(self._custom_filters)
        queryset._group_bys = copy(self._group_bys)
        queryset._select_for_update = self._select_for_update
        queryset._select_for_update_nowait = self._select_for_update_nowait
        queryset._select_for_update_skip_locked = self._select_for_update_skip_locked
        queryset._select_for_update_of = copy(self._select_for_update_of)
        queryset._select_related = copy(self._select_related)
        queryset._select_related_idx = copy(self._select_related_idx)
        queryset._force_indexes = copy(self._force_indexes)
        queryset._use_indexes = copy(self._use_indexes)
        return queryset

    def _filter_or_exclude(self, *args: Q, negate: bool, **kwargs: Any) -> "QuerySet[MODEL]":
        """
        Shared implementation of ``filter``/``exclude``.

        Every positional ``Q`` and every keyword (wrapped in its own ``Q``) is
        appended to a clone's ``Q`` list, inverted with ``~`` when ``negate`` is set.

        :raises TypeError: If a positional argument is not a ``Q`` instance.
        """
        queryset = self._clone()
        for arg in args:
            if not isinstance(arg, Q):
                raise TypeError("expected Q objects as args")
            queryset._q_objects.append(~arg if negate else arg)

        for key, value in kwargs.items():
            condition = Q(**{key: value})
            queryset._q_objects.append(~condition if negate else condition)

        return queryset

    def filter(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL]":
        """
        Filters QuerySet by given kwargs. You can filter by related objects like this:

        .. code-block:: python3

            Team.filter(events__tournament__name='Test')

        You can also pass Q objects to filters as args.
        """
        return self._filter_or_exclude(*args, negate=False, **kwargs)

    def exclude(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL]":
        """
        Same as .filter(), but wraps every given condition in NOT.
        """
        return self._filter_or_exclude(*args, negate=True, **kwargs)

    def order_by(self, *orderings: str) -> "QuerySet[MODEL]":
        """
        Accept args to filter by in format like this:

        .. code-block:: python3

            .order_by('name', '-tournament__name')

        Supports ordering by related models too.
        A '-' before the name will result in descending sort order, default is ascending.
        Replaces any ordering set by a previous call.

        :raises FieldError: If unknown field has been provided.
        """
        queryset = self._clone()
        new_ordering = []
        for ordering in orderings:
            field_name, order_type = self._resolve_ordering_string(ordering)

            if not (
                field_name.split("__")[0] in self.model._meta.fields
                or field_name in self._annotations
            ):
                raise FieldError(f"Unknown field {field_name} for model {self.model.__name__}")
            new_ordering.append((field_name, order_type))
        queryset._orderings = new_ordering
        return queryset

    def limit(self, limit: int) -> "QuerySet[MODEL]":
        """
        Limits QuerySet to given length.

        :raises ParamsError: Limit should be non-negative number.
        """
        if limit < 0:
            raise ParamsError("Limit should be non-negative number")

        queryset = self._clone()
        queryset._limit = limit
        return queryset

    def offset(self, offset: int) -> "QuerySet[MODEL]":
        """
        Query offset for QuerySet.

        On backends whose capabilities require a LIMIT alongside OFFSET, a limit
        of 1000000 is added when none has been set.

        :raises ParamsError: Offset should be non-negative number.
        """
        if offset < 0:
            raise ParamsError("Offset should be non-negative number")

        queryset = self._clone()
        queryset._offset = offset
        if self.capabilities.requires_limit and queryset._limit is None:
            queryset._limit = 1000000
        return queryset

    def __getitem__(self, key: slice) -> "QuerySet[MODEL]":
        """
        Query offset and limit for Queryset.

        :raises ParamsError: QuerySet indices must be slices.

        :raises ParamsError: Slice steps should be 1 or None.

        :raises ParamsError: Slice start should be non-negative number or None.

        :raises ParamsError: Slice stop should be non-negative number greater than slice
            start, or None.
        """
        if not isinstance(key, slice):
            raise ParamsError("QuerySet indices must be slices.")

        if not (key.step is None or (isinstance(key.step, int) and key.step == 1)):
            raise ParamsError("Slice steps should be 1 or None.")

        start = key.start if key.start is not None else 0

        if not isinstance(start, int) or start < 0:
            raise ParamsError("Slice start should be non-negative number or None.")
        if key.stop is not None and (not isinstance(key.stop, int) or key.stop <= start):
            raise ParamsError(
                "Slice stop should be non-negative number greater that slice start, or None.",
            )

        queryset = self.offset(start)
        if key.stop:
            queryset = queryset.limit(key.stop - start)
        return queryset

    def distinct(self) -> "QuerySet[MODEL]":
        """
        Make QuerySet distinct.

        Only makes sense in combination with a ``.values()`` or ``.values_list()`` as it
        precedes all the fetched fields with a distinct.
        """
        queryset = self._clone()
        queryset._distinct = True
        return queryset

    def select_for_update(
        self, nowait: bool = False, skip_locked: bool = False, of: Tuple[str, ...] = ()
    ) -> "QuerySet[MODEL]":
        """
        Make QuerySet select for update.

        Returns a queryset that will lock rows until the end of the transaction,
        generating a SELECT ... FOR UPDATE SQL statement on supported databases.
        On backends without ``support_for_update`` this is a no-op returning ``self``.
        """
        if self.capabilities.support_for_update:
            queryset = self._clone()
            queryset._select_for_update = True
            queryset._select_for_update_nowait = nowait
            queryset._select_for_update_skip_locked = skip_locked
            queryset._select_for_update_of = set(of)
            return queryset
        return self

    def annotate(self, **kwargs: Union[Expression, Term]) -> "QuerySet[MODEL]":
        """
        Annotate result with aggregation or function result.

        Each value is stored as given (no type check is performed) and the filters
        returned by ``get_filters_for_field`` for its key are registered, so the
        annotation name can be referenced by later filters.
        """
        from tortoise.models import get_filters_for_field

        queryset = self._clone()
        for key, annotation in kwargs.items():
            queryset._annotations[key] = annotation
            queryset._custom_filters.update(get_filters_for_field(key, None, key))
        return queryset

    def group_by(self, *fields: str) -> "QuerySet[MODEL]":
        """
        Make QuerySet returns list of dict or tuple with group by.

        Must call before .values() or .values_list()
        """
        queryset = self._clone()
        queryset._group_bys = fields
        return queryset

    def values_list(self, *fields_: str, flat: bool = False) -> "ValuesListQuery[Literal[False]]":
        """
        Make QuerySet returns list of tuples for given args instead of objects.

        If call after `.get()`, `.get_or_none()` or `.first()` return tuples for given args instead of object.

        If ``flat=True`` and only one arg is passed can return flat list or just scalar.

        If no arguments are passed it will default to a tuple containing all fields
        in order of declaration, followed by the annotations.
        """
        fields_for_select_list = fields_ or (
            [field for field in self.model._meta.fields_map if field in self.model._meta.db_fields]
            + list(self._annotations.keys())
        )
        return ValuesListQuery(
            db=self._db,
            model=self.model,
            q_objects=self._q_objects,
            single=self._single,
            raise_does_not_exist=self._raise_does_not_exist,
            flat=flat,
            fields_for_select_list=fields_for_select_list,
            distinct=self._distinct,
            limit=self._limit,
            offset=self._offset,
            orderings=self._orderings,
            annotations=self._annotations,
            custom_filters=self._custom_filters,
            group_bys=self._group_bys,
            force_indexes=self._force_indexes,
            use_indexes=self._use_indexes,
        )

    def values(self, *args: str, **kwargs: str) -> "ValuesQuery[Literal[False]]":
        """
        Make QuerySet return dicts instead of objects.

        If call after `.get()`, `.get_or_none()` or `.first()` return dict instead of object.

        Can pass names of fields to fetch, or as a ``field_name='name_in_dict'`` kwarg.

        If no arguments are passed it will default to a dict containing all fields.

        :raises FieldError: If duplicate key has been provided.
        """
        if args or kwargs:
            fields_for_select: Dict[str, str] = {}
            for field in args:
                if field in fields_for_select:
                    raise FieldError(f"Duplicate key {field}")
                fields_for_select[field] = field

            for return_as, field in kwargs.items():
                if return_as in fields_for_select:
                    raise FieldError(f"Duplicate key {return_as}")
                fields_for_select[return_as] = field
        else:
            _fields = [
                field
                for field in self.model._meta.fields_map.keys()
                if field in self.model._meta.fields_db_projection.keys()
            ] + list(self._annotations.keys())

            fields_for_select = {field: field for field in _fields}

        return ValuesQuery(
            db=self._db,
            model=self.model,
            q_objects=self._q_objects,
            single=self._single,
            raise_does_not_exist=self._raise_does_not_exist,
            fields_for_select=fields_for_select,
            distinct=self._distinct,
            limit=self._limit,
            offset=self._offset,
            orderings=self._orderings,
            annotations=self._annotations,
            custom_filters=self._custom_filters,
            group_bys=self._group_bys,
            force_indexes=self._force_indexes,
            use_indexes=self._use_indexes,
        )

    def delete(self) -> "DeleteQuery":
        """
        Delete all objects in QuerySet.
        """
        return DeleteQuery(
            db=self._db,
            model=self.model,
            q_objects=self._q_objects,
            annotations=self._annotations,
            custom_filters=self._custom_filters,
            limit=self._limit,
            orderings=self._orderings,
        )

    def update(self, **kwargs: Any) -> "UpdateQuery":
        """
        Update all objects in QuerySet with given kwargs.

        .. admonition: Example:

            .. code-block:: py3

                await Employee.filter(occupation='developer').update(salary=5000)

        Will instead of returning a resultset, update the data in the DB itself.
        """
        return UpdateQuery(
            db=self._db,
            model=self.model,
            update_kwargs=kwargs,
            q_objects=self._q_objects,
            annotations=self._annotations,
            custom_filters=self._custom_filters,
            limit=self._limit,
            orderings=self._orderings,
        )

    def count(self) -> "CountQuery":
        """
        Return count of objects in queryset instead of objects.
        """
        return CountQuery(
            db=self._db,
            model=self.model,
            q_objects=self._q_objects,
            annotations=self._annotations,
            custom_filters=self._custom_filters,
            limit=self._limit,
            offset=self._offset,
            force_indexes=self._force_indexes,
            use_indexes=self._use_indexes,
        )

    def exists(self) -> "ExistsQuery":
        """
        Return True/False whether queryset exists.
        """
        return ExistsQuery(
            db=self._db,
            model=self.model,
            q_objects=self._q_objects,
            annotations=self._annotations,
            custom_filters=self._custom_filters,
            force_indexes=self._force_indexes,
            use_indexes=self._use_indexes,
        )

    def all(self) -> "QuerySet[MODEL]":
        """
        Return the whole QuerySet.
        Essentially a no-op except as the only operation.
        """
        return self._clone()

    def raw(self, sql: str) -> "RawSQLQuery":
        """
        Return the QuerySet from raw SQL
        """
        return RawSQLQuery(model=self.model, db=self._db, sql=sql)

    def first(self) -> QuerySetSingle[Optional[MODEL]]:
        """
        Limit queryset to one object and return one object instead of list.

        Awaiting the result yields ``None`` when no row matches.
        """
        queryset = self._clone()
        queryset._limit = 1
        queryset._single = True
        return queryset  # type: ignore

    def get(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL]:
        """
        Fetch exactly one object matching the parameters.

        When awaited, raises ``DoesNotExist`` if nothing matches and
        ``MultipleObjectsReturned`` if more than one row matches.
        """
        queryset = self.filter(*args, **kwargs)
        # Fetch two rows so that "more than one match" can be detected.
        queryset._limit = 2
        queryset._single = True
        queryset._raise_does_not_exist = True
        return queryset  # type: ignore

    async def in_bulk(
        self, id_list: Iterable[Union[str, int]], field_name: str
    ) -> Dict[str, MODEL]:
        """
        Return a dictionary mapping each value of ``field_name`` to its object,
        for the objects of this queryset whose ``field_name`` is in ``id_list``.

        :param id_list: A list of field values
        :param field_name: Must be a unique field
        """
        objs = await self.filter(**{f"{field_name}__in": id_list})
        return {getattr(obj, field_name): obj for obj in objs}

    def bulk_create(
        self,
        objects: Iterable[MODEL],
        batch_size: Optional[int] = None,
        ignore_conflicts: bool = False,
        update_fields: Optional[Iterable[str]] = None,
        on_conflict: Optional[Iterable[str]] = None,
    ) -> "BulkCreateQuery[MODEL]":
        """
        This method inserts the provided list of objects into the database in an efficient manner
        (generally only 1 query, no matter how many objects there are).

        :param on_conflict: On conflict index name
        :param update_fields: Update fields when conflicts
        :param ignore_conflicts: Ignore conflicts when inserting
        :param objects: List of objects to bulk create
        :param batch_size: How many objects are created in a single query
        :raises ValueError: If params do not meet specifications
        """
        if ignore_conflicts and update_fields:
            raise ValueError(
                "ignore_conflicts and update_fields are mutually exclusive.",
            )
        if not ignore_conflicts:
            if (update_fields and not on_conflict) or (on_conflict and not update_fields):
                raise ValueError("update_fields and on_conflict need set in same time.")
        return BulkCreateQuery(
            db=self._db,
            model=self.model,
            objects=objects,
            batch_size=batch_size,
            ignore_conflicts=ignore_conflicts,
            update_fields=update_fields,
            on_conflict=on_conflict,
        )

    def bulk_update(
        self,
        objects: Iterable[MODEL],
        fields: Iterable[str],
        batch_size: Optional[int] = None,
    ) -> "BulkUpdateQuery[MODEL]":
        """
        Update the given fields in each of the given objects in the database.

        :param objects: Objects to bulk update (any iterable; it is materialized
            into a list once, so generators are not exhausted by validation)
        :param fields: The fields to update
        :param batch_size: How many objects are updated in a single query
        :raises ValueError: If objects have no primary key set
        """
        # Materialize first: the pk check below iterates ``objects``, which would
        # otherwise leave a one-shot iterator empty for BulkUpdateQuery.
        objects = list(objects)
        if any(obj.pk is None for obj in objects):
            raise ValueError("All bulk_update() objects must have a primary key set.")
        return BulkUpdateQuery(
            db=self._db,
            model=self.model,
            q_objects=self._q_objects,
            annotations=self._annotations,
            custom_filters=self._custom_filters,
            limit=self._limit,
            orderings=self._orderings,
            objects=objects,
            fields=fields,
            batch_size=batch_size,
        )

    def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[Optional[MODEL]]:
        """
        Fetch exactly one object matching the parameters.

        When awaited, yields ``None`` if nothing matches and raises
        ``MultipleObjectsReturned`` if more than one row matches.
        """
        queryset = self.filter(*args, **kwargs)
        queryset._limit = 2
        queryset._single = True
        return queryset  # type: ignore

    def only(self, *fields_for_select: str) -> "QuerySet[MODEL]":
        """
        Fetch ONLY the specified fields to create a partial model.

        Persisting changes on the model is allowed only when:

        * All the fields you want to update is specified in ``<model>.save(update_fields=[...])``
        * You included the Model primary key in the ``.only(...)``

        To protect against common mistakes we ensure that errors get raised:

        * If you access a field that is not specified, you will get an ``AttributeError``.
        * If you do a ``<model>.save()`` a ``IncompleteInstanceError`` will be raised as the model is, as requested, incomplete.
        * If you do a ``<model>.save(update_fields=[...])`` and you didn't include the primary key in the ``.only(...)``,
          then ``IncompleteInstanceError`` will be raised indicating that updates can't be done without the primary key being known.
        * If you do a ``<model>.save(update_fields=[...])`` and one of the fields in ``update_fields`` was not in the ``.only(...)``,
          then ``IncompleteInstanceError`` as that field is not available to be updated.
        """
        queryset = self._clone()
        queryset._fields_for_select = fields_for_select
        return queryset

    def force_index(self, *index_names: str) -> "QuerySet[MODEL]":
        """
        The FORCE INDEX hint acts like USE INDEX (index_list),
        with the addition that a table scan is assumed to be very expensive.

        A no-op returning ``self`` on backends without index-hint support.
        """
        if self.capabilities.support_index_hint:
            queryset = self._clone()
            queryset._force_indexes.update(index_names)
            return queryset
        return self

    def use_index(self, *index_names: str) -> "QuerySet[MODEL]":
        """
        The USE INDEX (index_list) hint tells MySQL to use only one of the named indexes to find rows in the table.

        A no-op returning ``self`` on backends without index-hint support.
        """
        if self.capabilities.support_index_hint:
            queryset = self._clone()
            queryset._use_indexes.update(index_names)
            return queryset
        return self

    async def explain(self) -> Any:
        """Fetch and return information about the query execution plan.

        This is done by executing an ``EXPLAIN`` query whose exact prefix depends
        on the database backend, as documented below.

        - PostgreSQL: ``EXPLAIN (FORMAT JSON, VERBOSE) ...``
        - SQLite: ``EXPLAIN QUERY PLAN ...``
        - MySQL: ``EXPLAIN FORMAT=JSON ...``

        .. note::
            This is only meant to be used in an interactive environment for debugging
            and query optimization.
            **The output format may (and will) vary greatly depending on the database backend.**
        """
        query = self.as_query()
        return await self._db.executor_class(model=self.model, db=self._db).execute_explain(
            query
        )

    def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL]":
        """
        Executes query in provided db client.
        Useful for transactions workaround.

        A falsy ``_db`` keeps the currently bound connection.
        """
        queryset = self._clone()
        queryset._db = _db if _db else queryset._db
        return queryset

    def _join_table_with_select_related(
        self,
        model: "Type[Model]",
        table: Table,
        field: str,
        forwarded_fields: str,
        path: Iterable[Optional[str]],
    ) -> QueryBuilder:
        """
        Join the relation ``field`` of ``model`` and select all of its columns.

        Each related column is aliased as ``"<table name>.<column>"`` and the
        related model is recorded in ``_select_related_idx`` for the executor.
        ``forwarded_fields`` (``"a__b"``) continues the walk on the related model.

        :return: The updated ``self.query`` (also assigned to ``self.query``).
        :raises FieldError: If ``field`` is unknown, or is a plain field that is
            followed by further ``__`` parts.
        """
        if field in model._meta.fields_db_projection and forwarded_fields:
            raise FieldError(f'Field "{field}" for model "{model.__name__}" is not relation')

        field_object = cast(RelationalField, model._meta.fields_map.get(field))
        if not field_object:
            raise FieldError(f'Unknown field "{field}" for model "{model.__name__}"')

        table = self._join_table_by_field(table, field, field_object)
        related_fields = field_object.related_model._meta.db_fields
        append_item = (
            field_object.related_model,
            len(related_fields),
            field,
            model,
            path,
        )
        if append_item not in self._select_related_idx:
            self._select_related_idx.append(append_item)
        for related_field in related_fields:
            self.query = self.query.select(
                table[related_field].as_(f"{table.get_table_name()}.{related_field}")
            )

        if forwarded_fields:
            field, __, forwarded_fields_ = forwarded_fields.partition("__")
            self.query = self._join_table_with_select_related(
                model=field_object.related_model,
                table=table,
                field=field,
                forwarded_fields=forwarded_fields_,
                path=(*path, field),
            )
        return self.query

    def _make_query(self) -> None:
        """Build ``self.query`` from the state accumulated by the chaining methods."""
        # clean tmp records first
        self._select_related_idx = []
        self._joined_tables = []

        table = self.model._meta.basetable
        if self._fields_for_select:
            append_item = (
                self.model,
                len(self._fields_for_select),
                table,
                self.model,
                (None,),
            )
            if append_item not in self._select_related_idx:
                self._select_related_idx.append(append_item)
            db_fields_for_select = [
                table[self.model._meta.fields_db_projection[field]].as_(field)
                for field in self._fields_for_select
            ]
            self.query = copy(self.model._meta.basequery).select(*db_fields_for_select)
        else:
            self.query = copy(self.model._meta.basequery_all_fields)
            append_item = (
                self.model,
                len(self.model._meta.db_fields) + len(self._annotations),
                table,
                self.model,
                (None,),
            )
            if append_item not in self._select_related_idx:
                self._select_related_idx.append(append_item)

        self.resolve_ordering(
            self.model, self.model._meta.basetable, self._orderings, self._annotations
        )
        self.resolve_filters(
            model=self.model,
            q_objects=self._q_objects,
            annotations=self._annotations,
            custom_filters=self._custom_filters,
        )
        if self._limit is not None:
            self.query._limit = self._limit
        if self._offset:
            self.query._offset = self._offset
        if self._distinct:
            self.query._distinct = True
        if self._select_for_update:
            self.query = self.query.for_update(
                self._select_for_update_nowait,
                self._select_for_update_skip_locked,
                self._select_for_update_of,
            )
        if self._select_related:
            for field in self._select_related:
                field, __, forwarded_fields = field.partition("__")
                self.query = self._join_table_with_select_related(
                    model=self.model,
                    table=self.model._meta.basetable,
                    field=field,
                    forwarded_fields=forwarded_fields,
                    path=(None, field),
                )
        if self._force_indexes:
            # Reset any hints carried over from the copied base query.
            self.query._force_indexes = []
            self.query = self.query.force_index(*self._force_indexes)
        if self._use_indexes:
            self.query._use_indexes = []
            self.query = self.query.use_index(*self._use_indexes)

    def __await__(self) -> Generator[Any, None, List[MODEL]]:
        if self._db is None:
            # SELECT ... FOR UPDATE must go to a write connection.
            self._db = self._choose_db(self._select_for_update)  # type: ignore
        self._make_query()
        return self._execute().__await__()

    async def __aiter__(self) -> AsyncIterator[MODEL]:
        for val in await self:
            yield val

    async def _execute(self) -> List[MODEL]:
        instance_list = await self._db.executor_class(
            model=self.model,
            db=self._db,
            prefetch_map=self._prefetch_map,
            prefetch_queries=self._prefetch_queries,
            select_related_idx=self._select_related_idx,
        ).execute_select(self.query, custom_fields=list(self._annotations.keys()))
        if self._single:
            if len(instance_list) == 1:
                return instance_list[0]
            if not instance_list:
                if self._raise_does_not_exist:
                    raise DoesNotExist(self.model)
                return None  # type: ignore
            raise MultipleObjectsReturned(self.model)
        return instance_list
[docs]class UpdateQuery(AwaitableQuery):
__slots__ = (
"update_kwargs",
"q_objects",
"annotations",
"custom_filters",
"orderings",
"limit",
"values",
)
def __init__(
self,
model: Type[MODEL],
update_kwargs: Dict[str, Any],
db: BaseDBAsyncClient,
q_objects: List[Q],
annotations: Dict[str, Any],
custom_filters: Dict[str, FilterInfoDict],
limit: Optional[int],
orderings: List[Tuple[str, str]],
) -> None:
super().__init__(model)
self.update_kwargs = update_kwargs
self.q_objects = q_objects
self.annotations = annotations
self.custom_filters = custom_filters
self._db = db
self.limit = limit
self.orderings = orderings
self.values: List[Any] = []
def _make_query(self) -> None:
table = self.model._meta.basetable
self.query = self._db.query_class.update(table)
if self.capabilities.support_update_limit_order_by and self.limit:
self.query._limit = self.limit
self.resolve_ordering(self.model, table, self.orderings, self.annotations)
self.resolve_filters(
model=self.model,
q_objects=self.q_objects,
annotations=self.annotations,
custom_filters=self.custom_filters,
)
# Need to get executor to get correct column_map
executor = self._db.executor_class(model=self.model, db=self._db)
count = 0
for key, value in self.update_kwargs.items():
field_object = self.model._meta.fields_map.get(key)
if not field_object:
raise FieldError(f"Unknown keyword argument {key} for model {self.model}")
if field_object.pk:
raise IntegrityError(f"Field {key} is PK and can not be updated")
if isinstance(field_object, (ForeignKeyFieldInstance, OneToOneFieldInstance)):
fk_field: str = field_object.source_field # type: ignore
db_field = self.model._meta.fields_map[fk_field].source_field
value = executor.column_map[fk_field](
getattr(value, field_object.to_field_instance.model_field_name),
None,
)
else:
try:
db_field = self.model._meta.fields_db_projection[key]
except KeyError:
raise FieldError(f"Field {key} is virtual and can not be updated")
if isinstance(value, Term):
value = F.resolver_arithmetic_expression(self.model, value)[0]
elif isinstance(value, Function):
value = value.resolve(self.model, table)["field"]
else:
value = executor.column_map[key](value, None)
if isinstance(value, Term):
self.query = self.query.set(db_field, value)
else:
self.query = self.query.set(db_field, executor.parameter(count))
self.values.append(value)
count += 1
def __await__(self) -> Generator[Any, None, int]:
    """
    Build and run the UPDATE when the query object is awaited.

    If no connection is bound yet, one is chosen via ``_choose_db(True)``.
    Resolves to whatever ``_execute`` returns.
    """
    if self._db is None:
        # NOTE(review): ``True`` presumably asks the router for a write connection.
        self._db = self._choose_db(True)  # type: ignore
    self._make_query()
    pending = self._execute()
    return pending.__await__()
async def _execute(self) -> int:
    """Run the built UPDATE with ``self.values`` and return the first item of the driver result."""
    statement = str(self.query)
    outcome = await self._db.execute_query(statement, self.values)
    return outcome[0]
class DeleteQuery(AwaitableQuery):
    """
    Awaitable DELETE of the rows matched by a queryset.

    Awaiting it resolves to the first item of the driver result for the
    statement.
    """

    __slots__ = (
        "q_objects",
        "annotations",
        "custom_filters",
        "orderings",
        "limit",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        limit: Optional[int],
        orderings: List[Tuple[str, str]],
    ) -> None:
        super().__init__(model)
        self._db = db
        self.q_objects = q_objects
        self.annotations = annotations
        self.custom_filters = custom_filters
        self.orderings = orderings
        self.limit = limit

    def _make_query(self) -> None:
        """Build ``self.query``: the filtered base query flagged as DELETE FROM."""
        meta = self.model._meta
        self.query = copy(meta.basequery)
        # ORDER BY / LIMIT are only emitted for backends that accept them here.
        limited = self.capabilities.support_update_limit_order_by and self.limit
        if limited:
            self.query._limit = self.limit
            self.resolve_ordering(
                model=self.model,
                table=meta.basetable,
                orderings=self.orderings,
                annotations=self.annotations,
            )
        self.resolve_filters(
            model=self.model,
            q_objects=self.q_objects,
            annotations=self.annotations,
            custom_filters=self.custom_filters,
        )
        self.query._delete_from = True

    def __await__(self) -> Generator[Any, None, int]:
        if self._db is None:
            self._db = self._choose_db(True)  # type: ignore
        self._make_query()
        pending = self._execute()
        return pending.__await__()

    async def _execute(self) -> int:
        statement = str(self.query)
        outcome = await self._db.execute_query(statement)
        return outcome[0]
class ExistsQuery(AwaitableQuery):
    """
    Awaitable check for whether the queryset matches at least one row.

    Builds ``SELECT 1 ... LIMIT 1`` with the queryset filters and optional index
    hints. Resolves to ``bool`` of the first item of the driver result.
    """

    __slots__ = (
        "q_objects",
        "annotations",
        "custom_filters",
        "force_indexes",
        "use_indexes",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        force_indexes: Set[str],
        use_indexes: Set[str],
    ) -> None:
        super().__init__(model)
        self._db = db
        self.force_indexes = force_indexes
        self.use_indexes = use_indexes
        self.q_objects = q_objects
        self.annotations = annotations
        self.custom_filters = custom_filters

    def _make_query(self) -> None:
        """Build the filtered ``SELECT 1 ... LIMIT 1`` into ``self.query``."""
        self.query = copy(self.model._meta.basequery)
        self.resolve_filters(
            model=self.model,
            q_objects=self.q_objects,
            annotations=self.annotations,
            custom_filters=self.custom_filters,
        )
        # A single constant column on a single row is enough to detect a match.
        self.query._limit = 1
        self.query._select_other(ValueWrapper(1))
        # Index hints overwrite any hints already present on the copied query.
        forced = self.force_indexes
        if forced:
            self.query._force_indexes = []
            self.query = self.query.force_index(*forced)
        preferred = self.use_indexes
        if preferred:
            self.query._use_indexes = []
            self.query = self.query.use_index(*preferred)

    def __await__(self) -> Generator[Any, None, bool]:
        if self._db is None:
            self._db = self._choose_db()  # type: ignore
        self._make_query()
        pending = self._execute()
        return pending.__await__()

    async def _execute(self) -> bool:
        first, _rows = await self._db.execute_query(str(self.query))
        return bool(first)
class CountQuery(AwaitableQuery):
    """
    Awaitable ``COUNT(*)`` over the rows matched by a queryset.

    LIMIT/OFFSET are not sent to the database. They are applied to the total
    afterwards: the offset is subtracted (never going below zero) and the result
    is capped at the limit.
    """

    __slots__ = (
        "q_objects",
        "annotations",
        "custom_filters",
        "limit",
        "offset",
        "force_indexes",
        "use_indexes",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        limit: Optional[int],
        offset: Optional[int],
        force_indexes: Set[str],
        use_indexes: Set[str],
    ) -> None:
        super().__init__(model)
        self.q_objects = q_objects
        self.annotations = annotations
        self.custom_filters = custom_filters
        self.limit = limit
        self.offset = offset or 0
        self._db = db
        self.force_indexes = force_indexes
        self.use_indexes = use_indexes

    def _make_query(self) -> None:
        """Build the filtered ``SELECT COUNT(*)`` into ``self.query``."""
        self.query = copy(self.model._meta.basequery)
        self.resolve_filters(
            model=self.model,
            q_objects=self.q_objects,
            annotations=self.annotations,
            custom_filters=self.custom_filters,
        )
        count_term = Count("*")
        if self.query._groupbys:
            # With GROUP BY, a plain COUNT(*) would count rows per group. The
            # window form yields the number of groups on every result row.
            count_term = count_term.over()

        self.query._select_other(count_term)

        if self.force_indexes:
            self.query._force_indexes = []
            self.query = self.query.force_index(*self.force_indexes)
        if self.use_indexes:
            self.query._use_indexes = []
            self.query = self.query.use_index(*self.use_indexes)

    def __await__(self) -> Generator[Any, None, int]:
        if self._db is None:
            self._db = self._choose_db()  # type: ignore
        self._make_query()
        return self._execute().__await__()

    async def _execute(self) -> int:
        """
        Run the count query and apply offset/limit to the total.

        :return: ``0`` when no row comes back, otherwise
            ``min(max(total - offset, 0), limit)`` (no cap when ``limit`` is falsy).
        """
        _, result = await self._db.execute_query(str(self.query))
        if not result:
            return 0
        total = list(dict(result[0]).values())[0]
        # An offset past the end of the matched rows leaves nothing to count.
        # It must never produce a negative count.
        count = max(total - self.offset, 0)
        if self.limit and count > self.limit:
            return self.limit
        return count
class FieldSelectQuery(AwaitableQuery):
    """
    Base for queries that select individual fields rather than whole models.

    Provides helpers to add a field to the SELECT (a model column, an
    annotation, or a ``__``-separated path through relations, joined as
    needed). It also resolves group-by names and finds the converter that turns
    a fetched value back into Python.
    """

    # pylint: disable=W0223
    def __init__(self, model: Type[MODEL], annotations: Dict[str, Any]) -> None:
        super().__init__(model)
        self.annotations = annotations

    def _join_table_with_forwarded_fields(
        self, model: Type[MODEL], table: Table, field: str, forwarded_fields: str
    ) -> Tuple[Table, str]:
        """
        Resolve ``field`` on ``model`` to the table and column to select.

        When ``forwarded_fields`` is non-empty, ``field`` must be a relation. The
        related table is joined and resolution continues on the related model
        with the next ``__`` segment.

        :return: The (possibly joined) table and the database column name.
        :raises FieldError: ``field`` is unknown on ``model``, or a plain column
            is followed by further path segments.
        :raises ValueError: The path ends on a relation instead of a concrete field.
        """
        db_projection = model._meta.fields_db_projection
        if field in db_projection:
            if forwarded_fields:
                raise FieldError(f'Field "{field}" for model "{model.__name__}" is not relation')
            return table, db_projection[field]

        # Check the model currently being traversed. On recursive calls
        # ``model`` is the related model, not ``self.model``.
        if field in model._meta.fetch_fields and not forwarded_fields:
            raise ValueError(
                'Selecting relation "{}" is not possible, select concrete '
                "field on related model".format(field)
            )

        field_object = cast(RelationalField, model._meta.fields_map.get(field))
        if not field_object:
            raise FieldError(f'Unknown field "{field}" for model "{model.__name__}"')

        table = self._join_table_by_field(table, field, field_object)
        next_field, __, remaining = forwarded_fields.partition("__")
        return self._join_table_with_forwarded_fields(
            model=field_object.related_model,
            table=table,
            field=next_field,
            forwarded_fields=remaining,
        )

    def add_field_to_select_query(self, field: str, return_as: str) -> None:
        """
        Add ``field`` to the SELECT under the alias ``return_as``.

        Annotations are registered in ``self._annotations``. Model columns are
        selected from the base table, and ``relation__...`` paths are joined
        and selected from the related table.

        :raises ValueError: ``field`` names a relation itself.
        :raises FieldError: ``field`` cannot be resolved.
        """
        table = self.model._meta.basetable
        if field in self.annotations:
            self._annotations[return_as] = self.annotations[field]
            return

        if field in self.model._meta.fields_db_projection:
            db_field = self.model._meta.fields_db_projection[field]
            self.query._select_field(table[db_field].as_(return_as))
            return

        if field in self.model._meta.fetch_fields:
            raise ValueError(
                'Selecting relation "{}" is not possible, select '
                "concrete field on related model".format(field)
            )

        field_, __, forwarded_fields = field.partition("__")
        if field_ in self.model._meta.fetch_fields:
            related_table, related_db_field = self._join_table_with_forwarded_fields(
                model=self.model,
                table=table,
                field=field_,
                forwarded_fields=forwarded_fields,
            )
            self.query._select_field(related_table[related_db_field].as_(return_as))
            return

        raise FieldError(f'Unknown field "{field}" for model "{self.model.__name__}"')

    def resolve_to_python_value(self, model: Type[MODEL], field: str) -> Callable:
        """
        Return the callable that converts a fetched value of ``field`` to Python.

        Relations, names listed in ``db_native_fields``, and annotations without
        a ``field_object`` get an identity function. ``__`` paths are followed
        into the related model.

        :raises FieldError: ``field`` cannot be resolved on ``model``.
        """
        if field in model._meta.fetch_fields:
            # return as is to get whole model objects
            return lambda x: x

        if field in (x[1] for x in model._meta.db_native_fields):
            return lambda x: x

        if field in self.annotations:
            annotation = self.annotations[field]
            field_object = getattr(annotation, "field_object", None)
            if field_object:
                return field_object.to_python_value
            return lambda x: x

        if field in model._meta.fields_map:
            return model._meta.fields_map[field].to_python_value

        field_, __, forwarded_fields = field.partition("__")
        if field_ in model._meta.fetch_fields:
            new_model = model._meta.fields_map[field_].related_model  # type: ignore
            return self.resolve_to_python_value(new_model, forwarded_fields)
        raise FieldError(f'Unknown field "{field}" for model "{model}"')

    def _resolve_group_bys(self, *field_names: str) -> List[Term]:
        """Translate group-by names (annotations or field paths) into pypika terms."""
        group_bys: List[Term] = []
        for field_name in field_names:
            if field_name in self._annotations:
                # Annotations are grouped by the annotation's name.
                group_bys.append(Term(field_name))
                continue
            field, __, forwarded_fields = field_name.partition("__")
            related_table, related_db_field = self._join_table_with_forwarded_fields(
                model=self.model,
                table=self.model._meta.basetable,
                field=field,
                forwarded_fields=forwarded_fields,
            )
            group_bys.append(related_table[related_db_field].as_(field_name))
        return group_bys
class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]):
    """
    Awaitable SELECT returning rows as tuples (or bare values when ``flat``).

    Selected columns are aliased by their position (``"0"``, ``"1"``, ...). With
    ``single`` set, resolves to one row (or ``None``) instead of a list.
    """

    __slots__ = (
        "flat",
        "fields",
        "limit",
        "offset",
        "distinct",
        "orderings",
        "annotations",
        "custom_filters",
        "q_objects",
        "single",
        "raise_does_not_exist",
        "fields_for_select_list",
        "group_bys",
        "force_indexes",
        "use_indexes",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        single: bool,
        raise_does_not_exist: bool,
        fields_for_select_list: Union[Tuple[str, ...], List[str]],
        limit: Optional[int],
        offset: Optional[int],
        distinct: bool,
        orderings: List[Tuple[str, str]],
        flat: bool,
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        group_bys: Tuple[str, ...],
        force_indexes: Set[str],
        use_indexes: Set[str],
    ) -> None:
        super().__init__(model, annotations)
        if flat and (len(fields_for_select_list) != 1):
            raise TypeError("You can flat value_list only if contains one field")

        self.fields = dict(
            (str(position), name) for position, name in enumerate(fields_for_select_list)
        )
        self.fields_for_select_list = fields_for_select_list
        self.flat = flat
        self._db = db
        self.q_objects = q_objects
        self.single = single
        self.raise_does_not_exist = raise_does_not_exist
        self.limit = limit
        self.offset = offset
        self.distinct = distinct
        self.orderings = orderings
        self.custom_filters = custom_filters
        self.group_bys = group_bys
        self.force_indexes = force_indexes
        self.use_indexes = use_indexes

    def _make_query(self) -> None:
        """Build the SELECT of the positional columns into ``self.query``."""
        self._joined_tables = []
        self.query = copy(self.model._meta.basequery)
        for position, field_name in self.fields.items():
            self.add_field_to_select_query(field_name, position)

        base_table = self.model._meta.basetable
        self.resolve_ordering(
            model=self.model,
            table=base_table,
            orderings=self.orderings,
            annotations=self.annotations,
        )
        self.resolve_filters(
            model=self.model,
            q_objects=self.q_objects,
            annotations=self.annotations,
            custom_filters=self.custom_filters,
        )
        if self.limit:
            self.query._limit = self.limit
        if self.offset:
            self.query._offset = self.offset
        if self.distinct:
            self.query._distinct = True
        if self.group_bys:
            self.query._groupbys = self._resolve_group_bys(*self.group_bys)
        # Index hints overwrite any hints already present on the copied query.
        if self.force_indexes:
            self.query._force_indexes = []
            self.query = self.query.force_index(*self.force_indexes)
        if self.use_indexes:
            self.query._use_indexes = []
            self.query = self.query.use_index(*self.use_indexes)

    @overload
    def __await__(
        self: "ValuesListQuery[Literal[False]]",
    ) -> Generator[Any, None, List[Tuple[Any, ...]]]: ...

    @overload
    def __await__(
        self: "ValuesListQuery[Literal[True]]",
    ) -> Generator[Any, None, Tuple[Any, ...]]: ...

    def __await__(self) -> Generator[Any, None, Union[List[Any], Tuple[Any, ...]]]:
        if self._db is None:
            self._db = self._choose_db()  # type: ignore
        self._make_query()
        pending = self._execute()
        return pending.__await__()  # pylint: disable=E1101

    async def __aiter__(self: "ValuesListQuery[Any]") -> AsyncIterator[Any]:
        rows = await self
        for row in rows:
            yield row

    async def _execute(self) -> Union[List[Any], Tuple]:
        """Fetch rows, convert each column to Python, and apply flat/single rules."""
        _, rows = await self._db.execute_query(str(self.query))
        converters = [
            (alias, self.resolve_to_python_value(self.model, name))
            for alias, name in self.fields.items()
        ]
        if self.flat:
            convert = converters[0][1]
            values = [convert(row["0"]) for row in rows]
        else:
            values = [tuple(fn(row[alias]) for alias, fn in converters) for row in rows]

        if not self.single:
            return values
        if len(values) == 1:
            return values[0]
        if values:
            raise MultipleObjectsReturned(self.model)
        if self.raise_does_not_exist:
            raise DoesNotExist(self.model)
        return None  # type: ignore
class ValuesQuery(FieldSelectQuery, Generic[SINGLE]):
    """
    Awaitable SELECT returning rows as dicts keyed by the requested aliases.

    ``fields_for_select`` maps each output key to a field name or ``__`` path.
    With ``single`` set, resolves to one dict (or ``None``) instead of a list.
    """

    __slots__ = (
        "fields_for_select",
        "limit",
        "offset",
        "distinct",
        "orderings",
        "annotations",
        "custom_filters",
        "q_objects",
        "single",
        "raise_does_not_exist",
        "group_bys",
        "force_indexes",
        "use_indexes",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        single: bool,
        raise_does_not_exist: bool,
        fields_for_select: Dict[str, str],
        limit: Optional[int],
        offset: Optional[int],
        distinct: bool,
        orderings: List[Tuple[str, str]],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        group_bys: Tuple[str, ...],
        force_indexes: Set[str],
        use_indexes: Set[str],
    ) -> None:
        super().__init__(model, annotations)
        self._db = db
        self.q_objects = q_objects
        self.single = single
        self.raise_does_not_exist = raise_does_not_exist
        self.fields_for_select = fields_for_select
        self.limit = limit
        self.offset = offset
        self.distinct = distinct
        self.orderings = orderings
        self.custom_filters = custom_filters
        self.group_bys = group_bys
        self.force_indexes = force_indexes
        self.use_indexes = use_indexes

    def _make_query(self) -> None:
        """Build the SELECT of the aliased fields into ``self.query``."""
        self._joined_tables = []
        self.query = copy(self.model._meta.basequery)
        for alias, field_name in self.fields_for_select.items():
            self.add_field_to_select_query(field_name, alias)

        base_table = self.model._meta.basetable
        self.resolve_ordering(
            model=self.model,
            table=base_table,
            orderings=self.orderings,
            annotations=self.annotations,
        )
        self.resolve_filters(
            model=self.model,
            q_objects=self.q_objects,
            annotations=self.annotations,
            custom_filters=self.custom_filters,
        )
        if self.limit:
            self.query._limit = self.limit
        if self.offset:
            self.query._offset = self.offset
        if self.distinct:
            self.query._distinct = True
        if self.group_bys:
            self.query._groupbys = self._resolve_group_bys(*self.group_bys)
        # Index hints overwrite any hints already present on the copied query.
        if self.force_indexes:
            self.query._force_indexes = []
            self.query = self.query.force_index(*self.force_indexes)
        if self.use_indexes:
            self.query._use_indexes = []
            self.query = self.query.use_index(*self.use_indexes)

    @overload
    def __await__(
        self: "ValuesQuery[Literal[False]]",
    ) -> Generator[Any, None, List[Dict[str, Any]]]: ...

    @overload
    def __await__(
        self: "ValuesQuery[Literal[True]]",
    ) -> Generator[Any, None, Dict[str, Any]]: ...

    def __await__(
        self,
    ) -> Generator[Any, None, Union[List[Dict[str, Any]], Dict[str, Any]]]:
        if self._db is None:
            self._db = self._choose_db()  # type: ignore
        self._make_query()
        pending = self._execute()
        return pending.__await__()  # pylint: disable=E1101

    async def __aiter__(self: "ValuesQuery[Any]") -> AsyncIterator[Dict[str, Any]]:
        rows = await self
        for row in rows:
            yield row

    async def _execute(self) -> Union[List[dict], Dict]:
        """Fetch dict rows, convert values in place, and apply the single-row rules."""
        rows = await self._db.execute_query_dict(str(self.query))
        resolved = [
            (alias, self.resolve_to_python_value(self.model, field_name))
            for alias, field_name in self.fields_for_select.items()
        ]
        # ``types.LambdaType`` is ``types.FunctionType``, so every plain function
        # is skipped. Those are the identity lambdas from resolve_to_python_value.
        # Bound ``to_python_value`` methods are kept and applied.
        converters = [
            (alias, fn) for alias, fn in resolved if not isinstance(fn, types.LambdaType)
        ]
        if converters:
            for row in rows:
                for alias, fn in converters:
                    row[alias] = fn(row[alias])

        if not self.single:
            return rows
        if len(rows) == 1:
            return rows[0]
        if rows:
            raise MultipleObjectsReturned(self.model)
        if self.raise_does_not_exist:
            raise DoesNotExist(self.model)
        return None  # type: ignore
class RawSQLQuery(AwaitableQuery):
    """
    Awaitable wrapper that runs a raw SQL string through the model's executor.

    Resolves to the result of the executor's ``execute_select`` for the wrapped
    ``RawSQL`` term.
    """

    __slots__ = ("_sql", "_db")

    def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str):
        super().__init__(model)
        self._db = db
        self._sql = sql

    def _make_query(self) -> None:
        self.query = RawSQL(self._sql)

    async def _execute(self) -> Any:
        executor = self._db.executor_class(model=self.model, db=self._db)
        return await executor.execute_select(self.query)

    def __await__(self) -> Generator[Any, None, List[MODEL]]:
        if self._db is None:
            self._db = self._choose_db()  # type: ignore
        self._make_query()
        pending = self._execute()
        return pending.__await__()
class BulkUpdateQuery(UpdateQuery, Generic[MODEL]):
    """
    Update many model instances, each with its own values, in batches.

    For each batch of ``batch_size`` objects, one UPDATE is built. It sets every
    name in ``fields`` through a ``CASE WHEN <pk> = ... THEN ...`` expression and
    is restricted to the batch's primary keys. Awaiting resolves to the sum of
    the first items of each statement's driver result.
    """

    __slots__ = ("objects", "fields", "batch_size", "queries")

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        limit: Optional[int],
        orderings: List[Tuple[str, str]],
        objects: Iterable[MODEL],
        fields: Iterable[str],
        batch_size: Optional[int] = None,
    ):
        super().__init__(
            model,
            update_kwargs={},
            db=db,
            q_objects=q_objects,
            annotations=annotations,
            custom_filters=custom_filters,
            limit=limit,
            orderings=orderings,
        )
        self.objects = objects
        self.fields = fields
        self.batch_size = batch_size
        self.queries: List[QueryBuilder] = []

    def _make_query(self) -> None:
        """
        Build one UPDATE per batch of objects into ``self.queries``.

        ``self.queries`` is rebuilt from scratch on each call, so rendering the
        SQL and then awaiting does not run every statement twice.
        """
        table = self.model._meta.basetable
        self.query = self._db.query_class.update(table)
        if self.capabilities.support_update_limit_order_by and self.limit:
            self.query._limit = self.limit
            self.resolve_ordering(
                model=self.model,
                table=table,
                orderings=self.orderings,
                annotations=self.annotations,
            )
        self.resolve_filters(
            model=self.model,
            q_objects=self.q_objects,
            annotations=self.annotations,
            custom_filters=self.custom_filters,
        )
        self.queries = []
        # Materialize once: ``fields`` may be a one-shot iterator but is needed
        # for every batch.
        fields = list(self.fields)
        if not fields:
            # Nothing to SET, so there is no statement to run.
            return

        executor = self._db.executor_class(model=self.model, db=self._db)
        fields_map = self.model._meta.fields_map
        pk_attr = self.model._meta.pk_attr
        source_pk_attr = fields_map[pk_attr].source_field or pk_attr
        pk = Field(source_pk_attr)
        dialect = self._db.schema_generator.DIALECT
        for objects_item in chunk(self.objects, self.batch_size):
            query = copy(self.query)
            pk_list = [executor.column_map[pk_attr](obj.pk, None) for obj in objects_item]
            for field in fields:
                # SET must name the database column, which differs from the
                # attribute name when the field declares ``source_field``
                # (same resolution as for the primary key above).
                db_field = fields_map[field].source_field or field
                case = Case()
                for obj, pk_value in zip(objects_item, pk_list):
                    obj_field = obj._meta.fields_map[field]
                    field_value = obj_field.to_db_value(getattr(obj, field), obj)
                    then_value = self.query._wrapper_cls(field_value)
                    if dialect == "postgres":
                        # Presumably gives Postgres an explicit type for every
                        # CASE branch.
                        then_value = Cast(
                            then_value, obj_field.get_for_dialect(dialect, "SQL_TYPE")
                        )
                    case.when(pk == pk_value, then_value)
                query = query.set(db_field, case)
            query = query.where(pk.isin(pk_list))
            self.queries.append(query)

    async def _execute(self) -> int:
        count = 0
        for query in self.queries:
            count += (await self._db.execute_query(str(query)))[0]
        return count

    def sql(self, **kwargs) -> str:
        """Return all batch statements joined with ``;``."""
        self.as_query()
        return ";".join([str(query) for query in self.queries])
class BulkCreateQuery(AwaitableQuery, Generic[MODEL]):
    """
    Insert many model instances in batches of ``batch_size``.

    Instances whose ``_custom_generated_pk`` is set are inserted with
    ``insert_query_all`` (``regular_columns_all``). All others use
    ``insert_query`` (``regular_columns``). With ``ignore_conflicts`` or
    ``update_fields``, dedicated statements are prepared. ``update_fields`` adds
    an ON CONFLICT (``on_conflict`` columns) DO UPDATE clause.
    """

    __slots__ = (
        "objects",
        "ignore_conflicts",
        "batch_size",
        "_db",
        "executor",
        "insert_query",
        "insert_query_all",
        "update_fields",
        "on_conflict",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        objects: Iterable[MODEL],
        batch_size: Optional[int] = None,
        ignore_conflicts: bool = False,
        update_fields: Optional[Iterable[str]] = None,
        on_conflict: Optional[Iterable[str]] = None,
    ):
        super().__init__(model)
        self._db = db
        self.objects = objects
        self.batch_size = batch_size
        self.ignore_conflicts = ignore_conflicts
        self.update_fields = update_fields
        self.on_conflict = on_conflict

    def _make_query(self) -> None:
        """Prepare ``self.executor`` plus the two INSERT statements."""
        executor = self._db.executor_class(model=self.model, db=self._db)
        self.executor = executor
        if not (self.ignore_conflicts or self.update_fields):
            self.insert_query_all = executor.insert_query_all
            self.insert_query = executor.insert_query
            return

        _, columns = executor._prepare_insert_columns()
        plain_insert = executor._prepare_insert_statement(
            columns, ignore_conflicts=self.ignore_conflicts
        )
        full_insert = plain_insert
        if self.model._meta.generated_db_fields:
            _, columns_all = executor._prepare_insert_columns(include_generated=True)
            full_insert = executor._prepare_insert_statement(
                columns_all,
                has_generated=False,
                ignore_conflicts=self.ignore_conflicts,
            )
        if self.update_fields:
            alias = f"new_{self.model._meta.db_table}"
            full_insert = full_insert.as_(alias).on_conflict(*self.on_conflict)
            plain_insert = plain_insert.as_(alias).on_conflict(*self.on_conflict)
            for update_field in self.update_fields:
                full_insert = full_insert.do_update(update_field)
                plain_insert = plain_insert.do_update(update_field)
        self.insert_query = plain_insert
        self.insert_query_all = full_insert

    async def _execute(self) -> None:
        executor = self.executor
        for batch in chunk(self.objects, self.batch_size):
            with_custom_pk = []
            without_custom_pk = []
            for instance in batch:
                if instance._custom_generated_pk:
                    bucket, names = with_custom_pk, executor.regular_columns_all
                else:
                    bucket, names = without_custom_pk, executor.regular_columns
                bucket.append(
                    [
                        executor.column_map[name](getattr(instance, name), instance)
                        for name in names
                    ]
                )
            if with_custom_pk:
                await self._db.execute_many(str(self.insert_query_all), with_custom_pk)
            if without_custom_pk:
                await self._db.execute_many(str(self.insert_query), without_custom_pk)

    def __await__(self) -> Generator[Any, None, None]:
        if self._db is None:
            self._db = self._choose_db(True)  # type: ignore
        self._make_query()
        pending = self._execute()
        return pending.__await__()

    def sql(self, **kwargs) -> str:
        """Return the prepared INSERT statement(s), joined with ``;`` when both exist."""
        self.as_query()
        regular, with_all = self.insert_query, self.insert_query_all
        if regular and with_all:
            return ";".join([str(regular), str(with_all)])
        return str(regular or with_all)