import types
from copy import copy
from typing import (
TYPE_CHECKING,
Any,
AsyncIterator,
Callable,
Dict,
Generator,
Generic,
Iterable,
List,
Optional,
Set,
Tuple,
Type,
TypeVar,
Union,
cast,
overload,
)
from pypika import JoinType, Order, Table
from pypika.analytics import Count
from pypika.functions import Cast
from pypika.queries import QueryBuilder
from pypika.terms import Case, Field, Star, Term, ValueWrapper
from typing_extensions import Literal, Protocol
from tortoise.backends.base.client import BaseDBAsyncClient, Capabilities
from tortoise.exceptions import (
DoesNotExist,
FieldError,
IntegrityError,
MultipleObjectsReturned,
ParamsError,
)
from tortoise.expressions import Expression, Q, RawSQL, ResolveContext, ResolveResult
from tortoise.fields.relational import (
ForeignKeyFieldInstance,
OneToOneFieldInstance,
RelationalField,
)
from tortoise.filters import FilterInfoDict
from tortoise.query_utils import (
Prefetch,
QueryModifier,
TableCriterionTuple,
get_joins_for_related_field,
)
from tortoise.router import router
from tortoise.utils import chunk
# Empty placeholder - Should never be edited.
# AwaitableQuery.__init__ assigns this shared builder as the initial ``query`` value.
QUERY: QueryBuilder = QueryBuilder()

if TYPE_CHECKING:  # pragma: nocoverage
    # Imported for annotations only.
    from tortoise.models import Model

# The Model subclass a query object is bound to.
MODEL = TypeVar("MODEL", bound="Model")
# Covariant result type of the QuerySetSingle protocol.
T_co = TypeVar("T_co", covariant=True)
# Boolean type parameter; presumably parameterises the values query classes
# (defined outside this view) - TODO confirm.
SINGLE = TypeVar("SINGLE", bound=bool)
class QuerySetSingle(Protocol[T_co]):
    """
    Awaiting on this will resolve a single instance of the Model object, and not a sequence.
    """

    # Typing-only protocol: every method body below is ``...``.
    # pylint: disable=W0104
    def __await__(self) -> Generator[Any, None, T_co]: ...  # pragma: nocoverage

    # Chainable methods that keep the single-result typing.
    def prefetch_related(
        self, *args: Union[str, Prefetch]
    ) -> "QuerySetSingle[T_co]": ...  # pragma: nocoverage

    def select_related(self, *args: str) -> "QuerySetSingle[T_co]": ...  # pragma: nocoverage

    def annotate(
        self, **kwargs: Union[Expression, Term]
    ) -> "QuerySetSingle[T_co]": ...  # pragma: nocoverage

    def only(self, *fields_for_select: str) -> "QuerySetSingle[T_co]": ...  # pragma: nocoverage

    # The values variants are typed with ``Literal[True]`` here, whereas the
    # QuerySet methods of the same name are typed with ``Literal[False]``.
    def values_list(
        self, *fields_: str, flat: bool = False
    ) -> "ValuesListQuery[Literal[True]]": ...  # pragma: nocoverage

    def values(
        self, *args: str, **kwargs: str
    ) -> "ValuesQuery[Literal[True]]": ...  # pragma: nocoverage
class AwaitableQuery(Generic[MODEL]):
    """
    Base class for query objects built around a ``pypika`` ``QueryBuilder``.

    Holds the state common to every query type: the model, the chosen
    connection, the tables already joined, annotations, custom filters and
    Q objects. ``_make_query`` and ``_execute`` raise ``NotImplementedError``
    here and must be provided by subclasses.
    """

    # Attribute storage is fixed; subclasses add their own slots.
    __slots__ = (
        "query",
        "model",
        "_joined_tables",
        "_db",
        "capabilities",
        "_annotations",
        "_custom_filters",
        "_q_objects",
    )
def __init__(self, model: Type[MODEL]) -> None:
    """
    Set up the empty query state for ``model``.

    :param model: The Model class this query targets.
    """
    # What is being queried, and what its default connection can do.
    self.model: "Type[MODEL]" = model
    self.capabilities: Capabilities = model._meta.db.capabilities
    # No connection is chosen yet; see ``_choose_db``.
    self._db: BaseDBAsyncClient = None  # type: ignore
    # Start from the shared empty builder; ``_make_query`` replaces it.
    self.query: QueryBuilder = QUERY
    self._joined_tables: List[Table] = []
    # Filtering state.
    self._q_objects: List[Q] = []
    self._annotations: Dict[str, Union[Expression, Term]] = {}
    self._custom_filters: Dict[str, FilterInfoDict] = {}
def _choose_db(self, for_write: bool = False) -> BaseDBAsyncClient:
    """
    Return the connection that will be used if this query is executed now.

    An explicitly set ``_db`` wins. Otherwise the router is asked for a
    write or read connection, falling back to the model's default one.

    :param for_write: Ask the router for a write connection instead of a read one.
    :return: BaseDBAsyncClient:
    """
    if self._db:
        return self._db
    route = router.db_for_write if for_write else router.db_for_read
    return route(self.model) or self.model._meta.db
def _choose_db_if_not_chosen(self, for_write: bool = False) -> None:
    """
    Pick and store a connection unless one is already set.

    :param for_write: Forwarded to ``_choose_db``.
    """
    if self._db is not None:
        return
    self._db = self._choose_db(for_write)  # type: ignore
def resolve_filters(self) -> None:
    """
    Builds the common filters for a QuerySet.

    Resolves the annotations, folds every Q object into one modifier, adds
    the joins that modifier needs, then installs its WHERE and HAVING criteria
    on the query. A GROUP BY over all model columns is added when an
    aggregate annotation is combined with joins, HAVING or ordering.
    """
    has_aggregate = self._resolve_annotate()

    combined = QueryModifier()
    for q_object in self._q_objects:
        # A fresh context per Q object, as each resolve call receives its own.
        context = ResolveContext(
            model=self.model,
            table=self.model._meta.basetable,
            annotations=self._annotations,
            custom_filters=self._custom_filters,
        )
        combined &= q_object.resolve(context)

    # ``_join_table`` skips tables that are already joined.
    for join in combined.joins:
        self._join_table(join)

    self.query._havings = combined.having_criterion
    self.query._wheres = combined.where_criterion

    needs_group_by = has_aggregate and (
        self._joined_tables or self.query._havings or self.query._orderbys
    )
    if needs_group_by:
        self.query = self.query.groupby(
            *[self.model._meta.basetable[name] for name in self.model._meta.db_fields]
        )
def _join_table_by_field(
    self, table: Table, related_field_name: str, related_field: RelationalField
) -> Table:
    """
    Join every table needed to reach ``related_field`` from ``table``.

    :param table: Table the relation starts from.
    :param related_field_name: Name of the relational field on the model.
    :param related_field: The relational field object.
    :return: The table of the last join produced for the relation.
    """
    join_chain = get_joins_for_related_field(table, related_field, related_field_name)
    for table_and_criterion in join_chain:
        self._join_table(table_and_criterion)
    final_join = join_chain[-1]
    return final_join[0]
def _join_table(self, table_criterio_tuple: TableCriterionTuple) -> None:
    """
    LEFT OUTER JOIN a table unless it has been joined already.

    :param table_criterio_tuple: Pair of (table to join, ON criterion).
    """
    target = table_criterio_tuple[0]
    if target in self._joined_tables:
        return
    joiner = self.query.join(target, how=JoinType.left_outer)
    self.query = joiner.on(table_criterio_tuple[1])
    self._joined_tables.append(target)
@staticmethod
def _resolve_ordering_string(ordering: str, reverse: bool = False) -> Tuple[str, Order]:
    """
    Split an ordering string into a field name and a sort direction.

    A leading ``-`` means descending; anything else is ascending.

    :param ordering: Ordering expression such as ``"name"`` or ``"-name"``.
    :param reverse: Flip the resulting direction.
    :return: Tuple of (field name without the ``-`` prefix, order).
    """
    # ``startswith`` tolerates an empty string, where ``ordering[0]`` raised
    # IndexError; the empty field name is then rejected by the caller's
    # own field validation instead.
    descending = ordering.startswith("-")
    field_name = ordering[1:] if descending else ordering
    if reverse:
        descending = not descending
    return field_name, (Order.desc if descending else Order.asc)
def resolve_ordering(
    self,
    model: "Type[Model]",
    table: Table,
    orderings: Iterable[Tuple[str, Union[str, Order]]],
    annotations: Dict[str, Any],
) -> None:
    """
    Applies standard ordering to QuerySet.

    Each ordering is handled in one of three ways: a ``relation__field``
    path joins the related table and recurses into the related model, an
    annotation name orders by the annotation's term, and anything else must
    be a field of ``model``.

    :param model: The Model this queryset is based on.
    :param table: ``pypika.Table`` to keep track of the virtual SQL table
        (to allow self referential joins)
    :param orderings: What columns/order to order by
    :param annotations: Annotations that may be ordered on
    :raises FieldError: If a field provided does not exist in model.
    """
    # Do not apply default ordering for annotated queries to not mess them up
    # NOTE(review): the default is read from ``self.model``, not the ``model`` argument.
    if not orderings and self.model._meta.ordering and not annotations:
        orderings = self.model._meta.ordering

    for ordering in orderings:
        field_name = ordering[0]
        # A bare relation name cannot be ordered on; a nested field is required.
        if field_name in model._meta.fetch_fields:
            raise FieldError(
                "Filtering by relation is not possible. Filter by nested field of related model"
            )

        related_field_name, __, forwarded = field_name.partition("__")
        if related_field_name in model._meta.fetch_fields:
            # Ordering across a relation: join it, then resolve the remainder
            # of the path against the related model and the joined table.
            related_field = cast(RelationalField, model._meta.fields_map[related_field_name])
            related_table = self._join_table_by_field(table, related_field_name, related_field)
            self.resolve_ordering(
                related_field.related_model,
                related_table,
                [(forwarded, ordering[1])],
                {},
            )
        elif field_name in annotations:
            # Ordering on an annotation: a raw Term is used as is, any other
            # annotation is resolved to obtain its term.
            if isinstance(annotation := annotations[field_name], Term):
                term: Term = annotation
            else:
                annotation_info = annotation.resolve(
                    ResolveContext(
                        model=self.model,
                        table=table,
                        annotations=annotations,
                        custom_filters={},
                    )
                )
                term = annotation_info.term
            self.query = self.query.orderby(term, order=ordering[1])
        else:
            # Ordering on a plain model field.
            field_object = model._meta.fields_map.get(field_name)
            if not field_object:
                raise FieldError(f"Unknown field {field_name} for model {model.__name__}")
            # Use the underlying column name when the field declares one.
            field_name = field_object.source_field or field_name
            field = table[field_name]

            # Apply the field's "function_cast" hook for this dialect, if it has one.
            func = field_object.get_for_dialect(
                model._meta.db.capabilities.dialect, "function_cast"
            )
            if func:
                field = func(field_object, field)

            self.query = self.query.orderby(field, order=ordering[1])
def _resolve_annotate(self) -> bool:
    """
    Resolve all annotations and add them to the query's select list.

    Joins required by an annotation are added to the query as well.

    :return: ``True`` when at least one annotation term is an aggregate,
        ``False`` otherwise (including when there are no annotations).
    """
    if not self._annotations:
        return False

    # First pass: turn every annotation into a ResolveResult.
    resolved: Dict[str, ResolveResult] = {}
    for name, annotation in self._annotations.items():
        if not isinstance(annotation, Term):
            resolved[name] = annotation.resolve(
                ResolveContext(
                    model=self.model,
                    table=self.model._meta.basetable,
                    annotations=self._annotations,
                    custom_filters=self._custom_filters,
                )
            )
            continue
        # Raw pypika terms need no resolving.
        resolved[name] = ResolveResult(term=annotation)

    # Second pass: apply joins and select each term under its annotation name.
    for name, result in resolved.items():
        for join in result.joins:
            self._join_table(join)
        if name in self._annotations:
            self.query._select_other(result.term.as_(name))  # type:ignore[arg-type]

    return any(result.term.is_aggregate for result in resolved.values())
def sql(self, params_inline=False) -> str:
    """
    Returns the SQL query that will be executed.

    The query is built first (choosing a connection if none is set). By
    default the statement contains placeholders; with ``params_inline=True``
    the parameters are inlined.

    :param params_inline: Whether to inline the parameters
    """
    self._choose_db_if_not_chosen()
    self._make_query()
    if not params_inline:
        statement, _params = self.query.get_parameterized_sql()
        return statement
    return self.query.get_sql()
def _make_query(self) -> None:
    """Build ``self.query``; must be implemented by subclasses."""
    raise NotImplementedError()  # pragma: nocoverage
async def _execute(self) -> Any:
    """Run the built query; must be implemented by subclasses."""
    raise NotImplementedError()  # pragma: nocoverage
class QuerySet(AwaitableQuery[MODEL]):
    """
    Chainable, awaitable SELECT query for a model.

    The chaining methods return a modified clone rather than changing the
    queryset they are called on. Awaiting the queryset builds and executes
    the query.
    """

    # ``_db`` is deliberately not listed: AwaitableQuery already declares that
    # slot, and re-declaring an inherited slot hides the base descriptor and
    # wastes space in every instance.
    __slots__ = (
        "fields",
        "_prefetch_map",
        "_prefetch_queries",
        "_single",
        "_raise_does_not_exist",
        "_limit",
        "_offset",
        "_fields_for_select",
        "_filter_kwargs",
        "_orderings",
        "_distinct",
        "_having",
        "_group_bys",
        "_select_for_update",
        "_select_for_update_nowait",
        "_select_for_update_skip_locked",
        "_select_for_update_of",
        "_select_related",
        "_select_related_idx",
        "_use_indexes",
        "_force_indexes",
    )
def __init__(self, model: Type[MODEL]) -> None:
    """
    Create an unfiltered, unordered queryset for ``model``.

    :param model: The Model class to select from.
    """
    super().__init__(model)
    self.fields: Set[str] = model._meta.db_fields

    # Result shape.
    self._single: bool = False
    self._raise_does_not_exist: bool = False
    self._fields_for_select: Tuple[str, ...] = ()
    self._distinct: bool = False

    # Paging, ordering and grouping.
    self._limit: Optional[int] = None
    self._offset: Optional[int] = None
    self._orderings: List[Tuple[str, Any]] = []
    self._group_bys: Tuple[str, ...] = ()
    self._having: Dict[str, Any] = {}
    self._filter_kwargs: Dict[str, Any] = {}

    # Related-object loading.
    self._prefetch_map: Dict[str, Set[Union[str, Prefetch]]] = {}
    self._prefetch_queries: Dict[str, List[Tuple[Optional[str], QuerySet]]] = {}
    self._select_related: Set[str] = set()
    self._select_related_idx: List[
        Tuple["Type[Model]", int, Union[Table, str], "Type[Model]", Iterable[Optional[str]]]
    ] = []  # format with: model,idx,model_name,parent_model

    # Row locking.
    self._select_for_update: bool = False
    self._select_for_update_nowait: bool = False
    self._select_for_update_skip_locked: bool = False
    self._select_for_update_of: Set[str] = set()

    # Index hints.
    self._force_indexes: Set[str] = set()
    self._use_indexes: Set[str] = set()
def _clone(self) -> "QuerySet[MODEL]":
    """
    Return a copy of this queryset that can be modified independently.

    The instance is created with ``__new__`` so ``__init__`` is not run.
    Every mutable container is shallow-copied, so in-place changes on the
    clone (for example adding an index hint) never reach the queryset it
    was cloned from.

    :return: A new queryset of the same class with the same state.
    """
    queryset = self.__class__.__new__(self.__class__)

    # Immutable values, or objects that are replaced rather than mutated.
    queryset.fields = self.fields
    queryset.model = self.model
    queryset.query = self.query
    queryset.capabilities = self.capabilities
    queryset._single = self._single
    queryset._raise_does_not_exist = self._raise_does_not_exist
    queryset._db = self._db
    queryset._limit = self._limit
    queryset._offset = self._offset
    queryset._fields_for_select = self._fields_for_select
    queryset._distinct = self._distinct
    queryset._select_for_update = self._select_for_update
    queryset._select_for_update_nowait = self._select_for_update_nowait
    queryset._select_for_update_skip_locked = self._select_for_update_skip_locked

    # Mutable containers: each clone gets its own shallow copy.
    queryset._prefetch_map = copy(self._prefetch_map)
    queryset._prefetch_queries = copy(self._prefetch_queries)
    queryset._filter_kwargs = copy(self._filter_kwargs)
    queryset._orderings = copy(self._orderings)
    queryset._joined_tables = copy(self._joined_tables)
    queryset._q_objects = copy(self._q_objects)
    queryset._annotations = copy(self._annotations)
    queryset._having = copy(self._having)
    queryset._custom_filters = copy(self._custom_filters)
    queryset._group_bys = copy(self._group_bys)
    # These were previously shared by reference with the source queryset.
    queryset._select_for_update_of = copy(self._select_for_update_of)
    queryset._select_related = copy(self._select_related)
    queryset._select_related_idx = copy(self._select_related_idx)
    queryset._force_indexes = copy(self._force_indexes)
    queryset._use_indexes = copy(self._use_indexes)
    return queryset
def _filter_or_exclude(self, *args: Q, negate: bool, **kwargs: Any) -> "QuerySet[MODEL]":
    """
    Return a clone with the given conditions appended to its Q objects.

    :param args: Q objects to add.
    :param negate: When true every condition is inverted before being added.
    :param kwargs: Lookups; each key/value pair becomes its own Q object.
    :raises TypeError: If a positional argument is not a Q object.
    """
    queryset = self._clone()
    for q_object in args:
        if not isinstance(q_object, Q):
            raise TypeError("expected Q objects as args")
        queryset._q_objects.append(~q_object if negate else q_object)

    for lookup, value in kwargs.items():
        condition = Q(**{lookup: value})
        queryset._q_objects.append(~condition if negate else condition)

    return queryset
def filter(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL]":
    """
    Filters QuerySet by given kwargs. Related objects can be filtered on too:

    .. code-block:: python3

        Team.filter(events__tournament__name='Test')

    Q objects may be passed as positional arguments.
    """
    return self._filter_or_exclude(*args, negate=False, **kwargs)
def exclude(self, *args: Q, **kwargs: Any) -> "QuerySet[MODEL]":
    """
    Same as .filter(), but every condition is added negated (NOT).
    """
    return self._filter_or_exclude(*args, negate=True, **kwargs)
def _parse_orderings(
    self, orderings: Tuple[str, ...], reverse=False
) -> List[Tuple[str, Order]]:
    """
    Convert ordering from strings to standard items for queryset.

    Only the first segment of a ``a__b`` path is checked against the model's
    fields; a full name may instead match an annotation.

    :param orderings: What columns/order to order by
    :param reverse: Whether reverse order
    :raises FieldError: If a name is neither a model field nor an annotation.
    :return: standard ordering for QuerySet.
    """
    parsed = []
    for raw in orderings:
        field_name, order_type = self._resolve_ordering_string(raw, reverse=reverse)
        head = field_name.split("__")[0]
        if head not in self.model._meta.fields and field_name not in self._annotations:
            raise FieldError(f"Unknown field {field_name} for model {self.model.__name__}")
        parsed.append((field_name, order_type))
    return parsed
def order_by(self, *orderings: str) -> "QuerySet[MODEL]":
    """
    Return a clone ordered by the given fields, for example:

    .. code-block:: python3

        .order_by('name', '-tournament__name')

    Ordering by fields of related models is supported.
    A '-' prefix sorts descending; without it the sort is ascending.

    :raises FieldError: If unknown field has been provided.
    """
    ordered = self._clone()
    parsed = self._parse_orderings(orderings)
    ordered._orderings = parsed
    return ordered
def _as_single(self) -> QuerySetSingle[Optional[MODEL]]:
    """
    Switch this queryset, in place, to return one object instead of a list.

    The limit is set to 1. Callers invoke this on a fresh clone.
    """
    self._limit = 1
    self._single = True
    return cast(QuerySetSingle[Optional[MODEL]], self)
def latest(self, *orderings: str) -> QuerySetSingle[Optional[MODEL]]:
    """
    Returns the most recent object: the given orderings are applied in
    reverse and the first row is taken.

    :params orderings: Fields to order by.
    :raises FieldError: If unknown or no fields has been provided.
    """
    if orderings:
        queryset = self._clone()
        queryset._orderings = self._parse_orderings(orderings, reverse=True)
        return queryset._as_single()
    raise FieldError("No fields passed")
def earliest(self, *orderings: str) -> QuerySetSingle[Optional[MODEL]]:
    """
    Returns the earliest object: the given orderings are applied as written
    and the first row is taken.

    :params orderings: Fields to order by.
    :raises FieldError: If unknown or no fields has been provided.
    """
    if orderings:
        queryset = self._clone()
        queryset._orderings = self._parse_orderings(orderings)
        return queryset._as_single()
    raise FieldError("No fields passed")
def limit(self, limit: int) -> "QuerySet[MODEL]":
    """
    Return a clone restricted to at most ``limit`` rows.

    :param limit: Maximum number of rows.
    :raises ParamsError: Limit should be non-negative number.
    """
    if limit < 0:
        raise ParamsError("Limit should be non-negative number")
    limited = self._clone()
    limited._limit = limit
    return limited
def offset(self, offset: int) -> "QuerySet[MODEL]":
    """
    Return a clone that skips the first ``offset`` rows.

    When the backend reports ``requires_limit`` and no limit is set yet,
    a limit of 1000000 is applied as well.

    :param offset: Number of rows to skip.
    :raises ParamsError: Offset should be non-negative number.
    """
    if offset < 0:
        raise ParamsError("Offset should be non-negative number")
    shifted = self._clone()
    shifted._offset = offset
    limit_missing = shifted._limit is None
    if self.capabilities.requires_limit and limit_missing:
        shifted._limit = 1000000
    return shifted
def __getitem__(self, key: slice) -> "QuerySet[MODEL]":
    """
    Translate a slice into an offset and a limit on a new queryset.

    :raises ParamsError: QuerySet indices must be slices.
    :raises ParamsError: Slice steps should be 1 or None.
    :raises ParamsError: Slice start should be non-negative number or None.
    :raises ParamsError: Slice stop should be non-negative number greater that slice start,
                       or None.
    """
    if not isinstance(key, slice):
        raise ParamsError("QuerySet indices must be slices.")

    step, stop = key.step, key.stop
    if step is not None and not (isinstance(step, int) and step == 1):
        raise ParamsError("Slice steps should be 1 or None.")

    # A missing start means "from the first row".
    start = 0 if key.start is None else key.start
    if not isinstance(start, int) or start < 0:
        raise ParamsError("Slice start should be non-negative number or None.")

    if stop is not None and (not isinstance(stop, int) or stop <= start):
        raise ParamsError(
            "Slice stop should be non-negative number greater that slice start, or None.",
        )

    queryset = self.offset(start)
    if stop:
        queryset = queryset.limit(stop - start)
    return queryset
def distinct(self) -> "QuerySet[MODEL]":
    """
    Return a clone that selects DISTINCT rows.

    Only makes sense in combination with a ``.values()`` or ``.values_list()`` as it
    precedes all the fetched fields with a distinct.
    """
    deduplicated = self._clone()
    deduplicated._distinct = True
    return deduplicated
def select_for_update(
    self, nowait: bool = False, skip_locked: bool = False, of: Tuple[str, ...] = ()
) -> "QuerySet[MODEL]":
    """
    Make QuerySet select for update.

    Returns a queryset that will lock rows until the end of the transaction,
    generating a SELECT ... FOR UPDATE SQL statement on supported databases.
    When the backend does not support it, this queryset is returned unchanged.

    :param nowait: Stored as the NOWAIT option of the lock.
    :param skip_locked: Stored as the SKIP LOCKED option of the lock.
    :param of: Names stored as the OF option of the lock.
    """
    if not self.capabilities.support_for_update:
        return self
    locking = self._clone()
    locking._select_for_update = True
    locking._select_for_update_nowait = nowait
    locking._select_for_update_skip_locked = skip_locked
    locking._select_for_update_of = set(of)
    return locking
def annotate(self, **kwargs: Union[Expression, Term]) -> "QuerySet[MODEL]":
    """
    Annotate result with aggregation or function result.

    Each keyword becomes an annotation under that name, and filters for the
    name are registered so the annotation can be filtered on.
    No type check is performed on the values.
    """
    # Imported here rather than at module level (module-level use of
    # tortoise.models is limited to TYPE_CHECKING).
    from tortoise.models import get_filters_for_field

    annotated = self._clone()
    for name, expression in kwargs.items():
        annotated._annotations[name] = expression
        annotated._custom_filters.update(get_filters_for_field(name, None, name))
    return annotated
def group_by(self, *fields: str) -> "QuerySet[MODEL]":
    """
    Return a clone that records the given GROUP BY fields.

    Must call before .values() or .values_list(), which receive these fields.
    """
    grouped = self._clone()
    grouped._group_bys = fields
    return grouped
def values_list(self, *fields_: str, flat: bool = False) -> "ValuesListQuery[Literal[False]]":
    """
    Make QuerySet returns list of tuples for given args instead of objects.

    If call after `.get()`, `.get_or_none()` or `.first()` return tuples for given args instead of object.

    If ``flat=True`` and only one arg is passed can return flat list or just scalar.

    Without arguments the selection defaults to every database field of the
    model, in ``fields_map`` order, followed by the annotation names.
    """
    if fields_:
        fields_for_select_list = fields_
    else:
        db_fields = self.model._meta.db_fields
        model_fields = [name for name in self.model._meta.fields_map if name in db_fields]
        fields_for_select_list = model_fields + list(self._annotations.keys())

    return ValuesListQuery(
        db=self._db,
        model=self.model,
        q_objects=self._q_objects,
        annotations=self._annotations,
        custom_filters=self._custom_filters,
        fields_for_select_list=fields_for_select_list,
        flat=flat,
        single=self._single,
        raise_does_not_exist=self._raise_does_not_exist,
        distinct=self._distinct,
        limit=self._limit,
        offset=self._offset,
        orderings=self._orderings,
        group_bys=self._group_bys,
        force_indexes=self._force_indexes,
        use_indexes=self._use_indexes,
    )
def values(self, *args: str, **kwargs: str) -> "ValuesQuery[Literal[False]]":
    """
    Make QuerySet return dicts instead of objects.

    If call after `.get()`, `.get_or_none()` or `.first()` return dict instead of object.

    Can pass names of fields to fetch, or as a ``field_name='name_in_dict'`` kwarg.

    Without arguments the dict contains every field present in the model's
    database projection, followed by the annotations.

    :raises FieldError: If duplicate key has been provided.
    """
    if not (args or kwargs):
        projected = self.model._meta.fields_db_projection.keys()
        names = [name for name in self.model._meta.fields_map.keys() if name in projected]
        names += list(self._annotations.keys())
        fields_for_select: Dict[str, str] = {name: name for name in names}
    else:
        fields_for_select = {}
        # Positional names map to themselves; keyword entries map the
        # output key to the source field.
        requested = [(name, name) for name in args] + list(kwargs.items())
        for return_as, source in requested:
            if return_as in fields_for_select:
                raise FieldError(f"Duplicate key {return_as}")
            fields_for_select[return_as] = source

    return ValuesQuery(
        db=self._db,
        model=self.model,
        q_objects=self._q_objects,
        annotations=self._annotations,
        custom_filters=self._custom_filters,
        fields_for_select=fields_for_select,
        single=self._single,
        raise_does_not_exist=self._raise_does_not_exist,
        distinct=self._distinct,
        limit=self._limit,
        offset=self._offset,
        orderings=self._orderings,
        group_bys=self._group_bys,
        force_indexes=self._force_indexes,
        use_indexes=self._use_indexes,
    )
def delete(self) -> "DeleteQuery":
    """
    Build a query that deletes all objects in this QuerySet.

    The filters, annotations, limit and orderings of the queryset carry over.
    """
    return DeleteQuery(
        model=self.model,
        db=self._db,
        q_objects=self._q_objects,
        custom_filters=self._custom_filters,
        annotations=self._annotations,
        orderings=self._orderings,
        limit=self._limit,
    )
def update(self, **kwargs: Any) -> "UpdateQuery":
    """
    Build a query that updates all objects in this QuerySet with the given kwargs.

    .. admonition: Example:

        .. code-block:: py3

            await Employee.filter(occupation='developer').update(salary=5000)

    Rather than returning a resultset, this updates the data in the DB itself.
    """
    return UpdateQuery(
        model=self.model,
        db=self._db,
        update_kwargs=kwargs,
        q_objects=self._q_objects,
        custom_filters=self._custom_filters,
        annotations=self._annotations,
        orderings=self._orderings,
        limit=self._limit,
    )
def count(self) -> "CountQuery":
    """
    Build a query returning the number of objects in this queryset
    instead of the objects themselves.
    """
    return CountQuery(
        model=self.model,
        db=self._db,
        q_objects=self._q_objects,
        custom_filters=self._custom_filters,
        annotations=self._annotations,
        offset=self._offset,
        limit=self._limit,
        use_indexes=self._use_indexes,
        force_indexes=self._force_indexes,
    )
def exists(self) -> "ExistsQuery":
    """
    Build a query returning True/False for whether the queryset matches anything.
    """
    return ExistsQuery(
        model=self.model,
        db=self._db,
        q_objects=self._q_objects,
        custom_filters=self._custom_filters,
        annotations=self._annotations,
        use_indexes=self._use_indexes,
        force_indexes=self._force_indexes,
    )
def all(self) -> "QuerySet[MODEL]":
    """
    Return the whole QuerySet, as a clone.

    Essentially a no-op except as the only operation.
    """
    duplicate = self._clone()
    return duplicate
def raw(self, sql: str) -> "RawSQLQuery":
    """
    Build a raw SQL query for this queryset's model and connection.

    :param sql: The SQL statement to run.
    """
    return RawSQLQuery(sql=sql, model=self.model, db=self._db)
def first(self) -> QuerySetSingle[Optional[MODEL]]:
    """
    Limit queryset to one object and return one object instead of list.
    """
    return self._clone()._as_single()
def last(self) -> QuerySetSingle[Optional[MODEL]]:
    """
    Limit queryset to one object and return the last object instead of list.

    Existing orderings are inverted; without any, the primary key is
    ordered descending.

    :raises FieldError: If there is no ordering and the model has no pk.
    """
    queryset = self._clone()
    current = queryset._orderings
    if current:
        # Flip every direction so the last row comes first.
        queryset._orderings = [
            (field, Order.desc if order_type == Order.asc else Order.asc)
            for field, order_type in current
        ]
        return queryset._as_single()

    pk = self.model._meta.pk
    if not pk:
        raise FieldError(
            f"QuerySet has no ordering and model {self.model.__name__} has no pk defined"
        )
    queryset._orderings = [(pk.model_field_name, Order.desc)]
    return queryset._as_single()
def get(self, *args: Q, **kwargs: Any) -> QuerySetSingle[MODEL]:
    """
    Fetch exactly one object matching the parameters.

    The filtered queryset is flagged to raise when nothing matches. It is
    limited to two rows, so that more than one match can be detected.
    """
    matched = self.filter(*args, **kwargs)
    matched._raise_does_not_exist = True
    matched._single = True
    matched._limit = 2
    return matched  # type: ignore
async def in_bulk(
    self, id_list: Iterable[Union[str, int]], field_name: str
) -> Dict[str, MODEL]:
    """
    Return a dictionary mapping each of the given values to the object
    whose ``field_name`` has that value.

    :param id_list: A list of field values
    :param field_name: Must be a unique field
    """
    lookup = {f"{field_name}__in": id_list}
    matched = await self.filter(**lookup)
    by_value = {}
    for obj in matched:
        by_value[getattr(obj, field_name)] = obj
    return by_value
def bulk_create(
    self,
    objects: Iterable[MODEL],
    batch_size: Optional[int] = None,
    ignore_conflicts: bool = False,
    update_fields: Optional[Iterable[str]] = None,
    on_conflict: Optional[Iterable[str]] = None,
) -> "BulkCreateQuery[MODEL]":
    """
    This method inserts the provided list of objects into the database in an efficient manner
    (generally only 1 query, no matter how many objects there are).

    ``ignore_conflicts`` cannot be combined with ``update_fields``. When
    conflicts are not ignored, ``update_fields`` and ``on_conflict`` must be
    given together or not at all.

    :param on_conflict: On conflict index name
    :param update_fields: Update fields when conflicts
    :param ignore_conflicts: Ignore conflicts when inserting
    :param objects: List of objects to bulk create
    :param batch_size: How many objects are created in a single query
    :raises ValueError: If params do not meet specifications
    """
    if ignore_conflicts and update_fields:
        raise ValueError(
            "ignore_conflicts and update_fields are mutually exclusive.",
        )
    # Exactly one of the two being set is the invalid combination.
    if not ignore_conflicts and bool(update_fields) != bool(on_conflict):
        raise ValueError("update_fields and on_conflict need set in same time.")

    return BulkCreateQuery(
        model=self.model,
        db=self._db,
        objects=objects,
        batch_size=batch_size,
        ignore_conflicts=ignore_conflicts,
        on_conflict=on_conflict,
        update_fields=update_fields,
    )
def bulk_update(
    self,
    objects: Iterable[MODEL],
    fields: Iterable[str],
    batch_size: Optional[int] = None,
) -> "BulkUpdateQuery[MODEL]":
    """
    Update the given fields in each of the given objects in the database.

    :param objects: Objects to update; any iterable, including a generator
    :param fields: The fields to update
    :param batch_size: How many objects are updated in a single query
    :raises ValueError: If objects have no primary key set
    """
    # The primary-key check below iterates ``objects``. A one-shot iterable
    # would be exhausted by it and the query would receive nothing, so
    # anything that is not already a list or tuple is materialised first.
    if not isinstance(objects, (list, tuple)):
        objects = list(objects)

    if any(obj.pk is None for obj in objects):
        raise ValueError("All bulk_update() objects must have a primary key set.")

    return BulkUpdateQuery(
        db=self._db,
        model=self.model,
        q_objects=self._q_objects,
        annotations=self._annotations,
        custom_filters=self._custom_filters,
        limit=self._limit,
        orderings=self._orderings,
        objects=objects,
        fields=fields,
        batch_size=batch_size,
    )
def get_or_none(self, *args: Q, **kwargs: Any) -> QuerySetSingle[Optional[MODEL]]:
    """
    Fetch exactly one object matching the parameters.

    Unlike ``get``, the queryset is not flagged to raise when nothing matches.
    """
    matched = self.filter(*args, **kwargs)
    matched._single = True
    matched._limit = 2
    return matched  # type: ignore
def only(self, *fields_for_select: str) -> "QuerySet[MODEL]":
    """
    Fetch ONLY the specified fields to create a partial model.

    Persisting changes on the model is allowed only when:

    * All the fields you want to update is specified in ``<model>.save(update_fields=[...])``
    * You included the Model primary key in the ``.only(...)``

    To protect against common mistakes we ensure that errors get raised:

    * If you access a field that is not specified, you will get an ``AttributeError``.
    * If you do a ``<model>.save()`` a ``IncompleteInstanceError`` will be raised as the model is, as requested, incomplete.
    * If you do a ``<model>.save(update_fields=[...])`` and you didn't include the primary key in the ``.only(...)``,
      then ``IncompleteInstanceError`` will be raised indicating that updates can't be done without the primary key being known.
    * If you do a ``<model>.save(update_fields=[...])`` and one of the fields in ``update_fields`` was not in the ``.only(...)``,
      then ``IncompleteInstanceError`` as that field is not available to be updated.
    """
    partial = self._clone()
    partial._fields_for_select = fields_for_select
    return partial
def force_index(self, *index_names: str) -> "QuerySet[MODEL]":
    """
    The FORCE INDEX hint acts like USE INDEX (index_list),
    with the addition that a table scan is assumed to be very expensive.

    When the backend does not support index hints, this queryset is returned
    unchanged.

    :param index_names: Names of the indexes to force.
    """
    if not self.capabilities.support_index_hint:
        return self
    queryset = self._clone()
    # Build a new set instead of adding in place: the clone's set may be the
    # very object the original queryset still holds.
    queryset._force_indexes = {*queryset._force_indexes, *index_names}
    return queryset
def use_index(self, *index_names: str) -> "QuerySet[MODEL]":
    """
    The USE INDEX (index_list) hint tells MySQL to use only one of the named indexes to find rows in the table.

    When the backend does not support index hints, this queryset is returned
    unchanged.

    :param index_names: Names of the indexes to suggest.
    """
    if not self.capabilities.support_index_hint:
        return self
    queryset = self._clone()
    # Build a new set instead of adding in place: the clone's set may be the
    # very object the original queryset still holds.
    queryset._use_indexes = {*queryset._use_indexes, *index_names}
    return queryset
async def explain(self) -> Any:
    """Fetch and return information about the query execution plan.

    This is done by executing an ``EXPLAIN`` query whose exact prefix depends
    on the database backend, as documented below.

    - PostgreSQL: ``EXPLAIN (FORMAT JSON, VERBOSE) ...``
    - SQLite: ``EXPLAIN QUERY PLAN ...``
    - MySQL: ``EXPLAIN FORMAT=JSON ...``

    .. note::
        This is only meant to be used in an interactive environment for debugging
        and query optimization.
        **The output format may (and will) vary greatly depending on the database backend.**
    """
    self._choose_db_if_not_chosen()
    self._make_query()
    executor = self._db.executor_class(model=self.model, db=self._db)
    # The statement is sent with its parameters inlined.
    return await executor.execute_explain(self.query.get_sql())
def using_db(self, _db: Optional[BaseDBAsyncClient]) -> "QuerySet[MODEL]":
    """
    Executes query in provided db client.
    Useful for transactions workaround.

    A falsy ``_db`` leaves the clone on the connection it already had.
    """
    bound = self._clone()
    if _db:
        bound._db = _db
    return bound
def _join_table_with_select_related(
    self,
    model: "Type[Model]",
    table: Table,
    field: str,
    forwarded_fields: str,
    path: Iterable[Optional[str]],
) -> "QueryBuilder":
    """
    Join the relation ``field`` of ``model`` and select all of its columns.

    Recurses through ``forwarded_fields`` (the rest of a ``a__b__c`` path),
    one relation per call.

    :param model: Model that owns ``field``.
    :param table: Table (or join alias) of ``model`` in the current query.
    :param field: Name of the relation to join.
    :param forwarded_fields: Remaining ``__``-separated path, or an empty string.
    :param path: Relation names walked so far, starting with ``None``.
    :raises FieldError: If ``field`` is unknown, or is a plain column that is
        followed by a further path.
    :return: The updated query builder (also stored on ``self.query``).
    """
    # A plain column cannot be traversed any further.
    if field in model._meta.fields_db_projection and forwarded_fields:
        raise FieldError(f'Field "{field}" for model "{model.__name__}" is not relation')

    field_object = cast(RelationalField, model._meta.fields_map.get(field))
    if not field_object:
        raise FieldError(f'Unknown field "{field}" for model "{model.__name__}"')

    # From here on ``table`` refers to the joined related table.
    table = self._join_table_by_field(table, field, field_object)
    related_fields = field_object.related_model._meta.db_fields
    # Record (related model, number of its columns, field name, parent model, path)
    # for the executor, and select the columns only the first time this
    # exact entry is seen.
    append_item = (
        field_object.related_model,
        len(related_fields),
        field,
        model,
        path,
    )
    if append_item not in self._select_related_idx:
        self._select_related_idx.append(append_item)
        for related_field in related_fields:
            # Columns are aliased as "<table name>.<column>".
            self.query = self.query.select(
                table[related_field].as_(f"{table.get_table_name()}.{related_field}")
            )
    if forwarded_fields:
        # Peel off the next relation and continue from the related model.
        field, __, forwarded_fields_ = forwarded_fields.partition("__")
        self.query = self._join_table_with_select_related(
            model=field_object.related_model,
            table=table,
            field=field,
            forwarded_fields=forwarded_fields_,
            path=(*path, field),
        )
        return self.query
    return self.query
def _make_query(self) -> None:
    """
    Rebuild ``self.query`` from the state stored on this queryset.

    Starts from the model's base query, then applies ordering, filters,
    limit/offset, DISTINCT, FOR UPDATE, select_related joins and index hints.
    """
    # clean tmp records first
    self._select_related_idx = []
    self._joined_tables = []

    table = self.model._meta.basetable
    if self._fields_for_select:
        # ``.only(...)`` was used: select just the requested columns.
        append_item = (
            self.model,
            len(self._fields_for_select),
            table,
            self.model,
            (None,),
        )
        # The list was reset above, so this entry is always new here.
        if append_item not in self._select_related_idx:
            self._select_related_idx.append(append_item)
        db_fields_for_select = [
            table[self.model._meta.fields_db_projection[field]].as_(field)
            for field in self._fields_for_select
        ]
        self.query = copy(self.model._meta.basequery).select(*db_fields_for_select)
    else:
        # Full model: copy the pre-built all-fields query; the recorded
        # column count covers the db fields plus the annotations.
        self.query = copy(self.model._meta.basequery_all_fields)  # type:ignore[assignment]
        append_item = (
            self.model,
            len(self.model._meta.db_fields) + len(self._annotations),
            table,
            self.model,
            (None,),
        )
        # The list was reset above, so this entry is always new here.
        if append_item not in self._select_related_idx:
            self._select_related_idx.append(append_item)

    # Ordering is resolved before the filters; resolve_filters reads the
    # query's order-by list when deciding whether to add a GROUP BY.
    self.resolve_ordering(
        self.model, self.model._meta.basetable, self._orderings, self._annotations
    )
    self.resolve_filters()
    if self._limit is not None:
        self.query._limit = self.query._wrapper_cls(self._limit)
    if self._offset is not None:
        self.query._offset = self.query._wrapper_cls(self._offset)
    if self._distinct:
        self.query._distinct = True
    if self._select_for_update:
        self.query = self.query.for_update(
            self._select_for_update_nowait,
            self._select_for_update_skip_locked,
            self._select_for_update_of,
        )
    if self._select_related:
        for field in self._select_related:
            # Split "a__b__c" into the first relation and the remaining path.
            field, __, forwarded_fields = field.partition("__")
            self.query = self._join_table_with_select_related(
                model=self.model,
                table=self.model._meta.basetable,
                field=field,
                forwarded_fields=forwarded_fields,
                path=(None, field),
            )
    # Any hints already on the builder are cleared before this queryset's
    # own hints are applied.
    if self._force_indexes:
        self.query._force_indexes = []
        self.query = self.query.force_index(*self._force_indexes)
    if self._use_indexes:
        self.query._use_indexes = []
        self.query = self.query.use_index(*self._use_indexes)
def __await__(self) -> Generator[Any, None, List[MODEL]]:
    """
    Build the query and hand back the awaitable of its execution.

    If no connection is set yet, one is chosen first; a write connection is
    requested when the queryset selects FOR UPDATE.
    """
    connection_missing = self._db is None
    if connection_missing:
        self._db = self._choose_db(self._select_for_update)  # type: ignore
    self._make_query()
    pending = self._execute()
    return pending.__await__()
async def __aiter__(self) -> AsyncIterator[MODEL]:
    """
    Await the queryset, then yield its results one by one.
    """
    results = await self
    for item in results:
        yield item
async def _execute(self) -> List[MODEL]:
    """
    Run the built SELECT through the connection's executor.

    :return: The list of fetched instances. In single mode, the only
        instance, or ``None`` when nothing matched and
        ``_raise_does_not_exist`` is falsy.
    :raises DoesNotExist: Single mode, no rows and ``_raise_does_not_exist`` set.
    :raises MultipleObjectsReturned: Single mode and more than one row.
    """
    executor = self._db.executor_class(
        model=self.model,
        db=self._db,
        prefetch_map=self._prefetch_map,
        prefetch_queries=self._prefetch_queries,
        select_related_idx=self._select_related_idx,  # type: ignore
    )
    parameterized = self.query.get_parameterized_sql()
    instance_list = await executor.execute_select(
        *parameterized,
        custom_fields=list(self._annotations.keys()),
    )
    if not self._single:
        return instance_list

    found = len(instance_list)
    if found == 1:
        return instance_list[0]
    if found:
        raise MultipleObjectsReturned(self.model)
    if self._raise_does_not_exist:
        raise DoesNotExist(self.model)
    return None  # type: ignore
class UpdateQuery(AwaitableQuery):
    """
    Awaitable query that issues an ``UPDATE`` built from keyword arguments.

    Awaiting it resolves to the first element of what the client's
    ``execute_query`` returns.
    """

    __slots__ = (
        "update_kwargs",
        "_orderings",
        "_limit",
        "values",
    )

    def __init__(
        self,
        model: Type[MODEL],
        update_kwargs: Dict[str, Any],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        limit: Optional[int],
        orderings: List[Tuple[str, str]],
    ) -> None:
        super().__init__(model)
        self._db = db
        self._q_objects = q_objects
        self._custom_filters = custom_filters
        self._annotations = annotations
        self._orderings = orderings
        self._limit = limit
        self.update_kwargs = update_kwargs

    def _make_query(self) -> None:
        """
        Build the UPDATE statement into ``self.query``.

        :raises FieldError: A key is not a field of the model, or names a
            field without a DB column.
        :raises IntegrityError: A key names the primary key.
        """
        meta = self.model._meta
        table = meta.basetable
        self.query = self._db.query_class.update(table)
        # LIMIT / ORDER BY are only emitted when the backend advertises support.
        if self.capabilities.support_update_limit_order_by and self._limit:
            self.query._limit = self.query._wrapper_cls(self._limit)
            self.resolve_ordering(self.model, table, self._orderings, self._annotations)

        self.resolve_filters()
        # Need to get executor to get correct column_map
        executor = self._db.executor_class(model=self.model, db=self._db)

        for key, value in self.update_kwargs.items():
            field_object = meta.fields_map.get(key)
            if not field_object:
                raise FieldError(f"Unknown keyword argument {key} for model {self.model}")
            if field_object.pk:
                raise IntegrityError(f"Field {key} is PK and can not be updated")

            if isinstance(field_object, (ForeignKeyFieldInstance, OneToOneFieldInstance)):
                # Relation: write the related object's key into the backing column.
                self.model._validate_relation_type(key, value)
                fk_field: str = field_object.source_field  # type: ignore
                db_field = meta.fields_map[fk_field].source_field
                to_db = executor.column_map[fk_field]
                related_key = getattr(value, field_object.to_field_instance.model_field_name)
                value = to_db(related_key, None)
                self.query = self.query.set(db_field, value)
                continue

            try:
                db_field = meta.fields_db_projection[key]
            except KeyError:
                raise FieldError(f"Field {key} is virtual and can not be updated")

            if isinstance(value, Expression):
                context = ResolveContext(
                    model=self.model,
                    table=table,
                    annotations=self._annotations,
                    custom_filters=self._custom_filters,
                )
                value = value.resolve(context).term
            else:
                value = executor.column_map[key](value, None)
            self.query = self.query.set(db_field, value)

    def __await__(self) -> Generator[Any, None, int]:
        self._choose_db_if_not_chosen(True)
        self._make_query()
        pending = self._execute()
        return pending.__await__()

    async def _execute(self) -> int:
        """Execute the statement and return element 0 of the client's result."""
        outcome = await self._db.execute_query(*self.query.get_parameterized_sql())
        return outcome[0]
class DeleteQuery(AwaitableQuery):
    """
    Awaitable query that issues a ``DELETE`` for the rows matched by the filters.

    Awaiting it resolves to the first element of what the client's
    ``execute_query`` returns.
    """

    __slots__ = (
        "_annotations",
        "_custom_filters",
        "_orderings",
        "_limit",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        limit: Optional[int],
        orderings: List[Tuple[str, str]],
    ) -> None:
        super().__init__(model)
        self._db = db
        self._q_objects = q_objects
        self._custom_filters = custom_filters
        self._annotations = annotations
        self._orderings = orderings
        self._limit = limit

    def _make_query(self) -> None:
        """Build the DELETE statement into ``self.query``."""
        meta = self.model._meta
        self.query = copy(meta.basequery)
        # LIMIT / ORDER BY are only emitted when the backend advertises support.
        if self.capabilities.support_update_limit_order_by and self._limit:
            self.query._limit = self.query._wrapper_cls(self._limit)
            self.resolve_ordering(
                model=self.model,
                table=meta.basetable,
                orderings=self._orderings,
                annotations=self._annotations,
            )
        self.resolve_filters()
        # Flip the copied base query into a DELETE.
        self.query._delete_from = True

    def __await__(self) -> Generator[Any, None, int]:
        self._choose_db_if_not_chosen(True)
        self._make_query()
        pending = self._execute()
        return pending.__await__()

    async def _execute(self) -> int:
        """Execute the statement and return element 0 of the client's result."""
        outcome = await self._db.execute_query(*self.query.get_parameterized_sql())
        return outcome[0]
class ExistsQuery(AwaitableQuery):
    """
    Awaitable query that reports whether any row matches the filters.

    The statement selects the constant ``1`` with ``LIMIT 1``; awaiting it
    resolves to ``bool()`` of the first element of the client's result.
    """

    __slots__ = (
        "_force_indexes",
        "_use_indexes",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        force_indexes: Set[str],
        use_indexes: Set[str],
    ) -> None:
        super().__init__(model)
        self._db = db
        self._q_objects = q_objects
        self._custom_filters = custom_filters
        self._annotations = annotations
        self._use_indexes = use_indexes
        self._force_indexes = force_indexes

    def _make_query(self) -> None:
        """Build the ``SELECT 1 ... LIMIT 1`` probe into ``self.query``."""
        self.query = copy(self.model._meta.basequery)
        self.resolve_filters()
        self.query._limit = self.query._wrapper_cls(1)
        constant_one = ValueWrapper(1, allow_parametrize=False)
        self.query._select_other(constant_one)  # type:ignore[arg-type]

        # Index hints: clear whatever the query carries, then apply ours.
        if self._force_indexes:
            self.query._force_indexes = []
            self.query = self.query.force_index(*self._force_indexes)
        if self._use_indexes:
            self.query._use_indexes = []
            self.query = self.query.use_index(*self._use_indexes)

    def __await__(self) -> Generator[Any, None, bool]:
        self._choose_db_if_not_chosen()
        self._make_query()
        pending = self._execute()
        return pending.__await__()

    async def _execute(
        self,
    ) -> bool:
        """Execute the probe and return the truthiness of the result's first element."""
        first, _unused = await self._db.execute_query(*self.query.get_parameterized_sql())
        return bool(first)
class CountQuery(AwaitableQuery):
    """
    Awaitable query that counts the rows matched by the filters.

    The database returns the full ``COUNT(*)``; limit and offset are applied
    to that number in Python inside ``_execute``.
    """

    __slots__ = (
        "_limit",
        "_offset",
        "_force_indexes",
        "_use_indexes",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        limit: Optional[int],
        offset: Optional[int],
        force_indexes: Set[str],
        use_indexes: Set[str],
    ) -> None:
        super().__init__(model)
        self._q_objects = q_objects
        self._annotations = annotations
        self._custom_filters = custom_filters
        self._limit = limit
        # A missing offset is treated as 0 so it can be subtracted directly.
        self._offset = offset or 0
        self._db = db
        self._force_indexes = force_indexes
        self._use_indexes = use_indexes

    def _make_query(self) -> None:
        """Build the ``SELECT COUNT(*)`` statement into ``self.query``."""
        self.query = copy(self.model._meta.basequery)
        self.resolve_filters()
        count_term = Count(Star())
        if self.query._groupbys:
            # Resolving the filters produced a GROUP BY: switch to the
            # windowed form of the count.
            count_term = count_term.over()

        # remove annotations
        self.query._selects = []
        self.query._select_other(count_term)

        if self._force_indexes:
            self.query._force_indexes = []
            self.query = self.query.force_index(*self._force_indexes)
        if self._use_indexes:
            self.query._use_indexes = []
            self.query = self.query.use_index(*self._use_indexes)

    def __await__(self) -> Generator[Any, None, int]:
        self._choose_db_if_not_chosen()
        self._make_query()
        return self._execute().__await__()

    async def _execute(self) -> int:
        """
        Execute the count and apply offset/limit to the returned number.

        :return: ``0`` when the database returns no rows; otherwise the count
            minus the offset, never below ``0`` and never above the limit
            (when a limit is set, including ``0``).
        """
        _, result = await self._db.execute_query(*self.query.get_parameterized_sql())
        if not result:
            return 0
        total = list(dict(result[0]).values())[0]
        # An offset past the end leaves nothing to count, never a negative number.
        remaining = max(total - self._offset, 0)
        # ``is not None`` so that an explicit limit of 0 caps the count at 0.
        if self._limit is not None and remaining > self._limit:
            return self._limit
        return remaining
class FieldSelectQuery(AwaitableQuery):
    """
    Base for queries that select individual (possibly related) fields.

    Provides helpers to join along ``relation__field`` paths, add the
    resulting column to the SELECT list, pick the DB-to-Python converter for
    a selected field, and build GROUP BY terms.
    """

    # pylint: disable=W0223

    def __init__(self, model: Type[MODEL], annotations: Dict[str, Any]) -> None:
        super().__init__(model)
        self._annotations = annotations

    def _join_table_with_forwarded_fields(
        self, model: Type[MODEL], table: Table, field: str, forwarded_fields: str
    ) -> Tuple[Table, str]:
        """
        Follow ``field`` (and the remaining ``forwarded_fields`` path) from
        ``model``, joining tables on the way.

        :param model: Model that ``field`` is looked up on at this step.
        :param table: Table that corresponds to ``model`` at this step.
        :param field: First segment of the path.
        :param forwarded_fields: Rest of the path (``__``-separated), may be empty.
        :return: The table holding the final column and that column's DB name.
        :raises FieldError: The path continues past a plain column, or a
            segment is unknown on the model it is looked up on.
        :raises ValueError: The path ends on a relation instead of a column.
        """
        meta = model._meta
        if field in meta.fields_db_projection:
            if forwarded_fields:
                raise FieldError(f'Field "{field}" for model "{model.__name__}" is not relation')
            return table, meta.fields_db_projection[field]

        # Check the model being traversed at this step (not the query's root
        # model): during recursion ``model`` is the related model, and the
        # relation must be detected there.
        if field in meta.fetch_fields and not forwarded_fields:
            raise ValueError(
                f'Selecting relation "{field}" is not possible, select concrete '
                "field on related model"
            )

        field_object = cast(RelationalField, meta.fields_map.get(field))
        if not field_object:
            raise FieldError(f'Unknown field "{field}" for model "{model.__name__}"')

        table = self._join_table_by_field(table, field, field_object)
        next_field, __, remaining = forwarded_fields.partition("__")
        return self._join_table_with_forwarded_fields(
            model=field_object.related_model,
            table=table,
            field=next_field,
            forwarded_fields=remaining,
        )

    def add_field_to_select_query(self, field: str, return_as: str) -> None:
        """
        Add ``field`` to the SELECT list under the alias ``return_as``.

        An annotation name is re-registered under ``return_as`` instead of
        being selected here.

        :raises ValueError: ``field`` names a relation rather than a column.
        :raises FieldError: ``field`` is unknown on the model.
        """
        meta = self.model._meta
        table = meta.basetable

        if field in self._annotations:
            self._annotations[return_as] = self._annotations[field]
            return

        if field in meta.fields_db_projection:
            db_field = meta.fields_db_projection[field]
            self.query._select_field(table[db_field].as_(return_as))
            return

        if field in meta.fetch_fields:
            raise ValueError(
                f'Selecting relation "{field}" is not possible, select '
                "concrete field on related model"
            )

        field_, __, forwarded_fields = field.partition("__")
        if field_ in meta.fetch_fields:
            related_table, related_db_field = self._join_table_with_forwarded_fields(
                model=self.model,
                table=table,
                field=field_,
                forwarded_fields=forwarded_fields,
            )
            self.query._select_field(related_table[related_db_field].as_(return_as))
            return

        raise FieldError(f'Unknown field "{field}" for model "{self.model.__name__}"')

    def resolve_to_python_value(self, model: Type[MODEL], field: str) -> Callable:
        """
        Return the callable that converts a raw DB value of ``field`` to Python.

        Relations, DB-native fields and annotations without a
        ``field_object`` get an identity lambda; other fields get their
        ``to_python_value``. ``relation__field`` paths are resolved
        recursively on the related model.

        :raises FieldError: ``field`` cannot be resolved on ``model``.
        """
        meta = model._meta
        if field in meta.fetch_fields:
            # return as is to get whole model objects
            return lambda x: x

        if field in (x[1] for x in meta.db_native_fields):
            return lambda x: x

        if field in self._annotations:
            annotation = self._annotations[field]
            field_object = getattr(annotation, "field_object", None)
            if field_object:
                return field_object.to_python_value
            return lambda x: x

        if field in meta.fields_map:
            return meta.fields_map[field].to_python_value

        field_, __, forwarded_fields = field.partition("__")
        if field_ in meta.fetch_fields:
            new_model = meta.fields_map[field_].related_model  # type: ignore
            return self.resolve_to_python_value(new_model, forwarded_fields)

        raise FieldError(f'Unknown field "{field}" for model "{model}"')

    def _resolve_group_bys(self, *field_names: str) -> List:
        """
        Translate field names into GROUP BY terms.

        Annotation names become a bare ``Term``; everything else is resolved
        (joining related tables when needed) to an aliased column.
        """
        group_bys = []
        for field_name in field_names:
            if field_name in self._annotations:
                group_bys.append(Term(field_name))
                continue
            first, __, forwarded_fields = field_name.partition("__")
            related_table, related_db_field = self._join_table_with_forwarded_fields(
                model=self.model,
                table=self.model._meta.basetable,
                field=first,
                forwarded_fields=forwarded_fields,
            )
            alias = f"{related_table.get_table_name()}__{field_name}"
            group_bys.append(related_table[related_db_field].as_(alias))
        return group_bys
class ValuesListQuery(FieldSelectQuery, Generic[SINGLE]):
    """
    Awaitable query that resolves to tuples of the selected field values
    (or to bare values when ``flat`` is set).

    Selected fields are aliased by their position (``"0"``, ``"1"``, ...)
    and read back by those aliases in ``_execute``.
    """

    __slots__ = (
        "fields",
        "_limit",
        "_offset",
        "_distinct",
        "_orderings",
        "_single",
        "_raise_does_not_exist",
        "_fields_for_select_list",
        "_flat",
        "_group_bys",
        "_force_indexes",
        "_use_indexes",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        single: bool,
        raise_does_not_exist: bool,
        fields_for_select_list: Union[Tuple[str, ...], List[str]],
        limit: Optional[int],
        offset: Optional[int],
        distinct: bool,
        orderings: List[Tuple[str, str]],
        flat: bool,
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        group_bys: Tuple[str, ...],
        force_indexes: Set[str],
        use_indexes: Set[str],
    ) -> None:
        super().__init__(model, annotations)
        if flat and (len(fields_for_select_list) != 1):
            raise TypeError("You can flat value_list only if contains one field")

        # Positional alias -> field name, e.g. {"0": "id", "1": "name"}.
        fields_for_select = {str(i): field for i, field in enumerate(fields_for_select_list)}
        self.fields = fields_for_select
        self._limit = limit
        self._offset = offset
        self._distinct = distinct
        self._orderings = orderings
        self._custom_filters = custom_filters
        self._q_objects = q_objects
        self._single = single
        self._raise_does_not_exist = raise_does_not_exist
        self._fields_for_select_list = fields_for_select_list
        self._flat = flat
        self._db = db
        self._group_bys = group_bys
        self._force_indexes = force_indexes
        self._use_indexes = use_indexes

    def _make_query(self) -> None:
        """Build the SELECT statement into ``self.query``."""
        self._joined_tables = []
        self.query = copy(self.model._meta.basequery)
        for positional_number, field in self.fields.items():
            self.add_field_to_select_query(field, positional_number)

        self.resolve_ordering(
            model=self.model,
            table=self.model._meta.basetable,
            orderings=self._orderings,
            annotations=self._annotations,
        )
        self.resolve_filters()

        # ``is not None`` rather than truthiness: an explicit limit of 0 must
        # produce ``LIMIT 0`` instead of silently returning every row.
        if self._limit is not None:
            self.query._limit = self.query._wrapper_cls(self._limit)
        if self._offset:
            self.query._offset = self.query._wrapper_cls(self._offset)
        if self._distinct:
            self.query._distinct = True
        if self._group_bys:
            self.query._groupbys = self._resolve_group_bys(*self._group_bys)

        if self._force_indexes:
            self.query._force_indexes = []
            self.query = self.query.force_index(*self._force_indexes)
        if self._use_indexes:
            self.query._use_indexes = []
            self.query = self.query.use_index(*self._use_indexes)

    @overload
    def __await__(
        self: "ValuesListQuery[Literal[False]]",
    ) -> Generator[Any, None, List[Tuple[Any, ...]]]: ...

    @overload
    def __await__(
        self: "ValuesListQuery[Literal[True]]",
    ) -> Generator[Any, None, Tuple[Any, ...]]: ...

    def __await__(self) -> Generator[Any, None, Union[List[Any], Tuple[Any, ...]]]:
        self._choose_db_if_not_chosen()
        self._make_query()
        return self._execute().__await__()  # pylint: disable=E1101

    async def __aiter__(self: "ValuesListQuery[Any]") -> AsyncIterator[Any]:
        for val in await self:
            yield val

    async def _execute(self) -> Union[List[Any], Tuple]:
        """
        Execute the query and convert every row.

        :return: A list of tuples (or of bare values when ``flat``). In
            single mode, the only entry, or ``None`` when nothing matched
            and ``_raise_does_not_exist`` is falsy.
        :raises DoesNotExist: Single mode, no rows and ``_raise_does_not_exist`` set.
        :raises MultipleObjectsReturned: Single mode and more than one row.
        """
        _, result = await self._db.execute_query(*self.query.get_parameterized_sql())
        columns = [
            (key, self.resolve_to_python_value(self.model, name))
            for key, name in self.fields.items()
        ]
        if self._flat:
            # ``flat`` guarantees exactly one field, aliased "0".
            func = columns[0][1]
            lst_values = [func(entry["0"]) for entry in result]
        else:
            lst_values = [
                tuple(convert(entry[column]) for column, convert in columns) for entry in result
            ]

        if self._single:
            if len(lst_values) == 1:
                return lst_values[0]
            if not lst_values:
                if self._raise_does_not_exist:
                    raise DoesNotExist(self.model)
                return None  # type: ignore
            raise MultipleObjectsReturned(self.model)
        return lst_values
class ValuesQuery(FieldSelectQuery, Generic[SINGLE]):
    """
    Awaitable query that resolves to dicts mapping the requested aliases to
    the selected field values.
    """

    __slots__ = (
        "_fields_for_select",
        "_limit",
        "_offset",
        "_distinct",
        "_orderings",
        "_single",
        "_raise_does_not_exist",
        "_group_bys",
        "_force_indexes",
        "_use_indexes",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        single: bool,
        raise_does_not_exist: bool,
        fields_for_select: Dict[str, str],
        limit: Optional[int],
        offset: Optional[int],
        distinct: bool,
        orderings: List[Tuple[str, str]],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        group_bys: Tuple[str, ...],
        force_indexes: Set[str],
        use_indexes: Set[str],
    ) -> None:
        super().__init__(model, annotations)
        self._fields_for_select = fields_for_select
        self._limit = limit
        self._offset = offset
        self._distinct = distinct
        self._orderings = orderings
        self._custom_filters = custom_filters
        self._q_objects = q_objects
        self._single = single
        self._raise_does_not_exist = raise_does_not_exist
        self._db = db
        self._group_bys = group_bys
        self._force_indexes = force_indexes
        self._use_indexes = use_indexes

    def _make_query(self) -> None:
        """Build the SELECT statement into ``self.query``."""
        self._joined_tables = []
        self.query = copy(self.model._meta.basequery)
        for return_as, field in self._fields_for_select.items():
            self.add_field_to_select_query(field, return_as)

        self.resolve_ordering(
            model=self.model,
            table=self.model._meta.basetable,
            orderings=self._orderings,
            annotations=self._annotations,
        )
        self.resolve_filters()

        # remove annotations that are not in fields_for_select
        self.query._selects = [
            select for select in self.query._selects if select.alias in self._fields_for_select
        ]

        # ``is not None`` rather than truthiness: an explicit limit of 0 must
        # produce ``LIMIT 0`` instead of silently returning every row.
        if self._limit is not None:
            self.query._limit = self.query._wrapper_cls(self._limit)
        if self._offset:
            self.query._offset = self.query._wrapper_cls(self._offset)
        if self._distinct:
            self.query._distinct = True
        if self._group_bys:
            self.query._groupbys = self._resolve_group_bys(*self._group_bys)

        if self._force_indexes:
            self.query._force_indexes = []
            self.query = self.query.force_index(*self._force_indexes)
        if self._use_indexes:
            self.query._use_indexes = []
            self.query = self.query.use_index(*self._use_indexes)

    @overload
    def __await__(
        self: "ValuesQuery[Literal[False]]",
    ) -> Generator[Any, None, List[Dict[str, Any]]]: ...

    @overload
    def __await__(
        self: "ValuesQuery[Literal[True]]",
    ) -> Generator[Any, None, Dict[str, Any]]: ...

    def __await__(
        self,
    ) -> Generator[Any, None, Union[List[Dict[str, Any]], Dict[str, Any]]]:
        self._choose_db_if_not_chosen()
        self._make_query()
        return self._execute().__await__()  # pylint: disable=E1101

    async def __aiter__(self: "ValuesQuery[Any]") -> AsyncIterator[Dict[str, Any]]:
        for val in await self:
            yield val

    async def _execute(self) -> Union[List[dict], Dict]:
        """
        Execute the query and convert the row dicts in place.

        :return: The list of row dicts. In single mode, the only row, or
            ``None`` when nothing matched and ``_raise_does_not_exist`` is falsy.
        :raises DoesNotExist: Single mode, no rows and ``_raise_does_not_exist`` set.
        :raises MultipleObjectsReturned: Single mode and more than one row.
        """
        result = await self._db.execute_query_dict(*self.query.get_parameterized_sql())
        resolved = [
            (alias, self.resolve_to_python_value(self.model, field_name))
            for alias, field_name in self._fields_for_select.items()
        ]
        # Plain functions (the identity lambdas handed out by
        # resolve_to_python_value) are skipped so that only real converters
        # are applied to each row.
        columns = [pair for pair in resolved if not isinstance(pair[1], types.LambdaType)]
        if columns:
            for row in result:
                for col, func in columns:
                    row[col] = func(row[col])

        if self._single:
            if len(result) == 1:
                return result[0]
            if not result:
                if self._raise_does_not_exist:
                    raise DoesNotExist(self.model)
                return None  # type: ignore
            raise MultipleObjectsReturned(self.model)
        return result
class RawSQLQuery(AwaitableQuery):
    """
    Awaitable query that runs a caller-supplied SQL string through the
    model's executor.
    """

    __slots__ = ("_sql", "_db")

    def __init__(self, model: Type[MODEL], db: BaseDBAsyncClient, sql: str) -> None:
        super().__init__(model)
        self._db = db
        self._sql = sql

    async def _execute(self) -> Any:
        """Run the SQL (with an empty parameter list) via ``execute_select``."""
        executor = self._db.executor_class(
            model=self.model,
            db=self._db,
        )
        statement = RawSQL(self._sql).get_sql()
        return await executor.execute_select(statement, [])

    def __await__(self) -> Generator[Any, None, List[MODEL]]:
        self._choose_db_if_not_chosen()
        pending = self._execute()
        return pending.__await__()
class BulkUpdateQuery(UpdateQuery, Generic[MODEL]):
    """
    Awaitable query that updates many objects with one ``UPDATE`` per batch.

    Each listed field is set to a ``CASE`` expression keyed on the primary
    key, and the batch is restricted with ``pk IN (...)``. Awaiting it
    resolves to the sum of the first elements of the per-batch results.
    """

    __slots__ = ("fields", "_objects", "_batch_size", "_queries")

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        q_objects: List[Q],
        annotations: Dict[str, Any],
        custom_filters: Dict[str, FilterInfoDict],
        limit: Optional[int],
        orderings: List[Tuple[str, str]],
        objects: Iterable[MODEL],
        fields: Iterable[str],
        batch_size: Optional[int] = None,
    ):
        super().__init__(
            model,
            update_kwargs={},
            db=db,
            q_objects=q_objects,
            annotations=annotations,
            custom_filters=custom_filters,
            limit=limit,
            orderings=orderings,
        )
        self.fields = fields
        self._objects = objects
        self._batch_size = batch_size
        self._queries: List[QueryBuilder] = []

    def _make_queries(self) -> List[Tuple[str, List[Any]]]:
        """
        Build one parameterized UPDATE per batch of objects.

        :return: ``(sql, params)`` pairs, one per batch, for this call only.
        """
        table = self.model._meta.basetable
        self.query = self._db.query_class.update(table)
        if self.capabilities.support_update_limit_order_by and self._limit:
            self.query._limit = self.query._wrapper_cls(self._limit)
            self.resolve_ordering(
                model=self.model,
                table=table,
                orderings=self._orderings,
                annotations=self._annotations,
            )

        self.resolve_filters()
        executor = self._db.executor_class(model=self.model, db=self._db)
        pk_attr = self.model._meta.pk_attr
        source_pk_attr = self.model._meta.fields_map[pk_attr].source_field or pk_attr
        pk = Field(source_pk_attr)
        to_pk_db_value = executor.column_map[pk_attr]
        dialect = self._db.schema_generator.DIALECT
        is_postgres = dialect == "postgres"

        # Start from scratch on every call: ``sql()`` followed by ``await``
        # (or awaiting twice) must not replay statements from an earlier build.
        self._queries = []
        for objects_item in chunk(self._objects, self._batch_size):
            # Materialize the batch: it is walked once for the PKs and once
            # per field, which a one-shot iterable would not survive.
            batch = list(objects_item)
            pk_values = [to_pk_db_value(obj.pk, None) for obj in batch]

            query = copy(self.query)
            for field in self.fields:
                case = Case()
                for obj, pk_value in zip(batch, pk_values):
                    field_obj = obj._meta.fields_map[field]
                    field_value = field_obj.to_db_value(getattr(obj, field), obj)
                    new_value = self.query._wrapper_cls(field_value)
                    if is_postgres:
                        # On postgres the value is wrapped in an explicit CAST
                        # to the field's SQL type.
                        new_value = Cast(
                            new_value,
                            field_obj.get_for_dialect(dialect, "SQL_TYPE"),
                        )
                    case.when(pk == pk_value, new_value)
                query = query.set(field, case)
            # One PK restriction per batch, independent of the field count.
            query = query.where(pk.isin(pk_values))
            self._queries.append(query)
        return [query.get_parameterized_sql() for query in self._queries]

    async def _execute_many(self, queries_with_params: List[Tuple[str, List[Any]]]) -> int:
        """Run every statement in order and sum element 0 of each result."""
        count = 0
        for sql, values in queries_with_params:
            count += (await self._db.execute_query(sql, values))[0]
        return count

    def __await__(self) -> Generator[Any, Any, int]:
        self._choose_db_if_not_chosen(True)
        queries = self._make_queries()
        return self._execute_many(queries).__await__()

    def sql(self, params_inline=False) -> str:
        """
        Return the SQL of all batches joined with ``;``.

        ``params_inline`` is accepted but not used by this implementation.
        """
        self._choose_db_if_not_chosen()
        queries = self._make_queries()
        return ";".join([sql for sql, _ in queries])
class BulkCreateQuery(AwaitableQuery, Generic[MODEL]):
    """
    Awaitable query that inserts many objects via the client's ``execute_many``.

    Two INSERT statements are prepared: one used for instances whose
    ``_custom_generated_pk`` is truthy (the "all columns" form) and one for
    the remaining instances.
    """

    __slots__ = (
        "_objects",
        "_ignore_conflicts",
        "_batch_size",
        "_db",
        "_executor",
        "_update_fields",
        "_on_conflict",
    )

    def __init__(
        self,
        model: Type[MODEL],
        db: BaseDBAsyncClient,
        objects: Iterable[MODEL],
        batch_size: Optional[int] = None,
        ignore_conflicts: bool = False,
        update_fields: Optional[Iterable[str]] = None,
        on_conflict: Optional[Iterable[str]] = None,
    ):
        super().__init__(model)
        self._db = db
        self._objects = objects
        self._batch_size = batch_size
        self._ignore_conflicts = ignore_conflicts
        self._on_conflict = on_conflict
        self._update_fields = update_fields

    def _make_queries(self) -> Tuple[str, str]:
        """
        Create the executor and return ``(insert_sql, insert_sql_all)``.

        Without conflict handling the executor's prepared statements are
        returned as they are; otherwise dedicated statements are built.
        """
        self._executor = self._db.executor_class(model=self.model, db=self._db)
        executor = self._executor
        if not (self._ignore_conflicts or self._update_fields):
            return executor.insert_query, executor.insert_query_all

        _, columns = executor._prepare_insert_columns()
        insert_query = executor._prepare_insert_statement(
            columns, ignore_conflicts=self._ignore_conflicts
        )
        insert_query_all = insert_query
        if self.model._meta.generated_db_fields:
            _, columns_all = executor._prepare_insert_columns(include_generated=True)
            insert_query_all = executor._prepare_insert_statement(
                columns_all,
                has_generated=False,
                ignore_conflicts=self._ignore_conflicts,
            )

        if self._update_fields:
            alias = f"new_{self.model._meta.db_table}"
            conflict_targets = self._on_conflict or []
            insert_query_all = insert_query_all.as_(alias).on_conflict(*conflict_targets)
            insert_query = insert_query.as_(alias).on_conflict(*conflict_targets)
            for update_field in self._update_fields:
                insert_query_all = insert_query_all.do_update(update_field)
                insert_query = insert_query.do_update(update_field)

        return insert_query.get_sql(), insert_query_all.get_sql()

    async def _execute_many(self, insert_sql: str, insert_sql_all: str) -> None:
        """Insert the objects chunk by chunk, split by ``_custom_generated_pk``."""

        def to_row(instance: Any, field_names: Iterable[str]) -> List[Any]:
            # One DB-ready value per column, produced by the executor's column map.
            return [
                self._executor.column_map[field_name](getattr(instance, field_name), instance)
                for field_name in field_names
            ]

        for instance_chunk in chunk(self._objects, self._batch_size):
            rows_all: List[List[Any]] = []
            rows: List[List[Any]] = []
            for instance in instance_chunk:
                if instance._custom_generated_pk:
                    rows_all.append(to_row(instance, self._executor.regular_columns_all))
                else:
                    rows.append(to_row(instance, self._executor.regular_columns))

            if rows_all:
                await self._db.execute_many(insert_sql_all, rows_all)
            if rows:
                await self._db.execute_many(insert_sql, rows)

    def __await__(self) -> Generator[Any, None, None]:
        self._choose_db_if_not_chosen(True)
        insert_sql, insert_sql_all = self._make_queries()
        pending = self._execute_many(insert_sql, insert_sql_all)
        return pending.__await__()

    def sql(self, params_inline=False) -> str:
        """
        Return the INSERT statement(s) that awaiting this query would use.

        Only the "all columns" statement when every object has a truthy
        ``_custom_generated_pk``, only the regular one when none has, and
        both joined with ``;`` otherwise. ``params_inline`` is accepted but
        not used by this implementation.
        """
        self._choose_db_if_not_chosen()
        insert_sql, insert_sql_all = self._make_queries()
        if all(obj._custom_generated_pk for obj in self._objects):
            return insert_sql_all
        if not any(obj._custom_generated_pk for obj in self._objects):
            return insert_sql
        return ";".join([insert_sql, insert_sql_all])