from __future__ import annotations
import operator
from collections.abc import Iterator
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from enum import Enum
from typing import TYPE_CHECKING, Any, cast
from pypika_tortoise import Case as PypikaCase
from pypika_tortoise import Field as PypikaField
from pypika_tortoise import SqlContext, Table
from pypika_tortoise.functions import AggregateFunction, DistinctOptionFunction
from pypika_tortoise.terms import (
ArithmeticExpression,
Criterion,
)
from pypika_tortoise.terms import Function as PypikaFunction
from pypika_tortoise.terms import (
Term,
ValueWrapper,
)
from pypika_tortoise.utils import format_alias_sql
from tortoise.exceptions import FieldError, OperationalError
from tortoise.fields.base import Field
from tortoise.fields.data import JSONField
from tortoise.fields.relational import RelationalField
from tortoise.filters import FilterInfoDict
from tortoise.query_utils import (
QueryModifier,
TableCriterionTuple,
get_joins_for_related_field,
resolve_field_json_path,
resolve_nested_field,
)
if TYPE_CHECKING: # pragma: nocoverage
from pypika_tortoise.queries import Selectable
from tortoise.models import Model
from tortoise.queryset import AwaitableQuery
@dataclass(frozen=True)
class ResolveContext:
    """
    Immutable bundle of the state an expression needs while resolving itself to SQL.
    """

    # Model class against which field names are looked up.
    model: type[Model]
    # pypika table the model's columns are taken from.
    table: Table
    # Annotation name -> annotation; values are either pypika Terms or Expressions.
    annotations: dict[str, Any]
    # Custom (annotation-based) filter name -> filter info.
    custom_filters: dict[str, FilterInfoDict]
@dataclass
class ResolveResult:
    """
    Outcome of resolving an expression: the SQL term plus what it depends on.
    """

    # The pypika term produced for the expression.
    term: Term
    # Joins required for ``term`` to be valid; each instance gets its own fresh list.
    joins: list[TableCriterionTuple] = dataclass_field(default_factory=list)
    # Model field describing the value's type when it is known, otherwise ``None``.
    output_field: Field | None = None
class Expression:
    """
    Abstract parent of every expression type in this module.

    Concrete subclasses must override :meth:`resolve`.
    """

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Turn this expression into a SQL term for the given context.

        :raises NotImplementedError: always, on the base class.
        """
        raise NotImplementedError
class Value(Expression):
    """
    Expression holding a plain Python value that is used as a literal term in a query.

    :param value: The value to embed; it is wrapped in a pypika ``ValueWrapper`` on resolve.
    """

    def __init__(self, value: Any) -> None:
        self.value = value

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """Return the wrapped value as a term; no joins or output field are produced."""
        wrapped = ValueWrapper(self.value)
        return ResolveResult(term=wrapped)
class Connector(Enum):
    """
    Arithmetic connectors for :class:`CombinedExpression`.

    Each member's value is the name of the function in the :mod:`operator` module
    that implements it (looked up with ``getattr(operator, connector.value)``).
    """

    add = "add"  # a + b
    sub = "sub"  # a - b
    mul = "mul"  # a * b
    div = "truediv"  # a / b
    pow = "pow"  # a ** b
    mod = "mod"  # a % b
class CombinedExpression(Expression):
    """
    Binary arithmetic expression ``left <connector> right``.

    :param left: Left operand.
    :param connector: Arithmetic operation to apply.
    :param right: Right operand; non-``Expression`` values are wrapped in :class:`Value`.
    """

    def __init__(self, left: Expression, connector: Connector, right: Any) -> None:
        self.left = left
        self.connector = connector
        self.right = right if isinstance(right, Expression) else Value(right)

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Resolve both operands and combine their terms with the connector's operator.

        :raises FieldError: if both operands have known output fields of different types.
        """
        left = self.left.resolve(resolve_context)
        right = self.right.resolve(resolve_context)
        left_output_field, right_output_field = left.output_field, right.output_field
        if (
            left_output_field
            and right_output_field
            and type(left_output_field) is not type(right_output_field)
        ):
            raise FieldError("Cannot use arithmetic expression between different field type")
        operator_func = getattr(operator, self.connector.value)
        return ResolveResult(
            term=operator_func(left.term, right.term),
            # Deduplicate while keeping first-seen order: a later join may reference the
            # table introduced by an earlier one, and set() order is not stable across runs.
            joins=list(dict.fromkeys(left.joins + right.joins)),
            output_field=right_output_field or left_output_field,
        )
class F(Expression):
    """
    F() can be used to reference a model field, field of a related model, annotation or
    an attribute of a JSON field. It can be used in the following ways:

    - as a field reference, e.g. F("id")
    - as a related field reference, e.g. F("related_field__field") will return the value of
      the field of the related model.
    - as a JSON field reference, e.g. F("json_field__attribute") will return the value of the
      "attribute" property of the JSON field value. The reference can be nested, e.g.
      F("json_field__attribute__subattribute")
    - as a JSON field array element reference, e.g. F("json_field__0") will return the first
      element of the array.

    F objects support the arithmetic operators ``+ - * / % **`` (both operand orders) and
    unary ``-``; each produces a :class:`CombinedExpression`.

    :param name: The name of the field to reference.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Resolve the reference to a SQL term.

        :raises FieldError: if the name is not a related field, JSON path, annotation or
            a field present in the model's db projection.
        """
        term: Term
        joins: list[TableCriterionTuple] = []
        output_field = None
        meta = resolve_context.model._meta
        main_name_part, __, rest_name_parts = self.name.partition("__")
        if main_name_part in meta.fetch_fields:
            # field in the format of "related_field__field" or "related_field__another_rel_field__field"
            term, joins, output_field = resolve_nested_field(
                resolve_context.model, resolve_context.table, self.name
            )
        elif (
            rest_name_parts
            and main_name_part in meta.fields_map
            and isinstance(meta.fields_map[main_name_part], JSONField)
        ):
            # Accessing a JSON field, e.g. F("json_field__a__b").
            # isdecimal() rather than isdigit(): isdigit() also accepts characters such as
            # "²" that int() rejects, which would crash instead of being used as a key.
            key_parts = [
                int(item) if item.isdecimal() else str(item)
                for item in rest_name_parts.split("__")
            ]
            term = resolve_field_json_path(
                PypikaField(meta.fields_db_projection[main_name_part]),
                key_parts,
            )
        elif self.name in resolve_context.annotations:
            # reference to another annotation, e.g. M.annotate(f1=...).annotate(f2=F("f1")).values('field')
            annotation = resolve_context.annotations[self.name]
            if isinstance(annotation, Term):
                term = annotation
            else:
                term = annotation.resolve(resolve_context).term
        else:
            # a regular model field, e.g. F("id")
            try:
                db_column = meta.fields_db_projection[self.name]
            except KeyError:
                raise FieldError(
                    f"There is no non-virtual field {self.name} on Model {resolve_context.model.__name__}"
                ) from None
            term = PypikaField(db_column)
            # Apply a dialect-specific cast function if the field defines one.
            if (output_field := meta.fields_map.get(self.name, None)) and (
                func := output_field.get_for_dialect(meta.db.capabilities.dialect, "function_cast")
            ):
                term = func(output_field, term)
        return ResolveResult(term=term, output_field=output_field, joins=joins)

    def _combine(self, other: Any, connector: Connector, right_hand: bool) -> CombinedExpression:
        """Build ``self <connector> other`` (or ``other <connector> self`` when ``right_hand``)."""
        if not isinstance(other, Expression):
            other = Value(other)
        if right_hand:
            return CombinedExpression(other, connector, self)
        return CombinedExpression(self, connector, other)

    def __neg__(self) -> CombinedExpression:
        return self._combine(-1, Connector.mul, False)

    def __add__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.add, False)

    def __sub__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.sub, False)

    def __mul__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.mul, False)

    def __truediv__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.div, False)

    def __mod__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.mod, False)

    def __pow__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.pow, False)

    def __radd__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.add, True)

    def __rsub__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.sub, True)

    def __rmul__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.mul, True)

    def __rtruediv__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.div, True)

    def __rmod__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.mod, True)

    def __rpow__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.pow, True)
class Subquery(Term):
    """
    Term wrapping a tortoise queryset so it can be embedded inside another query.

    The wrapped queryset is (re)built every time SQL or an alias is requested.

    :param query: The queryset to embed.
    """

    def __init__(self, query: AwaitableQuery) -> None:
        super().__init__()
        self.query = query

    def _build_inner_query(self):
        """Choose a database if needed, build the wrapped queryset, and return its pypika query."""
        inner = self.query
        inner._choose_db_if_not_chosen()
        inner._make_query()
        return inner.query

    def get_sql(self, ctx: SqlContext) -> str:
        """Return only the SQL string of the inner query's parameterized SQL."""
        return self._build_inner_query().get_parameterized_sql(ctx)[0]

    def as_(self, alias: str) -> Selectable:  # type: ignore
        """Return the built inner pypika query aliased as ``alias``."""
        return self._build_inner_query().as_(alias)
class RawSQL(Term):
    """
    Term that renders a caller-supplied SQL fragment verbatim (it is not escaped).

    :param sql: The SQL text to emit.
    """

    def __init__(self, sql: str) -> None:
        super().__init__()
        self.sql = sql

    def get_sql(self, ctx: SqlContext) -> str:
        """Return the raw SQL, decorated with the alias when the context asks for aliases."""
        if not ctx.with_alias:
            return self.sql
        return format_alias_sql(sql=self.sql, alias=self.alias, ctx=ctx)
class Q:
    """
    Q Expression container.

    Q Expressions are a useful tool to compose a query from many small parts.
    A Q node holds either child ``Q`` nodes or keyword filters (never both: when
    both are given, the keyword filters are folded into an extra child node).

    :param join_type: Is the join an AND or OR join type?
    :param args: Inner ``Q`` expressions that you want to wrap.
    :param kwargs: Filter statements that this Q object should encapsulate.
    :raises OperationalError: if a positional argument is not a ``Q`` or
        ``join_type`` is neither ``Q.AND`` nor ``Q.OR``.
    """

    __slots__ = (
        "children",
        "filters",
        "join_type",
        "_is_negated",
    )

    AND = "AND"
    OR = "OR"

    def __init__(self, *args: Q, join_type: str = AND, **kwargs: Any) -> None:
        if args and kwargs:
            newarg = Q(join_type=join_type, **kwargs)
            args = (newarg,) + args
            kwargs = {}
        if not all(isinstance(node, Q) for node in args):
            raise OperationalError("All ordered arguments must be Q nodes")
        #: Contains the sub-Q's that this Q is made up of
        self.children: tuple[Q, ...] = args
        #: Contains the filters applied to this Q
        self.filters: dict[str, FilterInfoDict] = kwargs
        if join_type not in {self.AND, self.OR}:
            raise OperationalError("join_type must be AND or OR")
        #: Specifies if this Q does an AND or OR on its children
        self.join_type = join_type
        self._is_negated = False

    def __and__(self, other: Q) -> Q:
        """
        Returns a binary AND of Q objects, use ``AND`` operator.

        :raises OperationalError: AND operation requires a Q node
        """
        if not isinstance(other, Q):
            raise OperationalError("AND operation requires a Q node")
        return Q(self, other, join_type=self.AND)

    def __or__(self, other: Q) -> Q:
        """
        Returns a binary OR of Q objects, use ``OR`` operator.

        :raises OperationalError: OR operation requires a Q node
        """
        if not isinstance(other, Q):
            raise OperationalError("OR operation requires a Q node")
        return Q(self, other, join_type=self.OR)

    def __invert__(self) -> Q:
        """
        Returns a negated copy of the Q object, use ``~`` operator.

        The copy's negation flag is the opposite of this object's, so ``~~q``
        is equivalent to ``q``. ``self`` is not modified.
        """
        q = Q(*self.children, join_type=self.join_type, **self.filters)
        # Invert relative to this node's state; always setting the flag would leave
        # a double negation still negated.
        q._is_negated = not self._is_negated
        return q

    def __eq__(self, other: object) -> bool:
        if not isinstance(other, Q):
            return False
        return (
            self.children == other.children
            and self.join_type == other.join_type
            and self.filters == other.filters
            # A negated Q selects the complement, so it must not compare equal to the plain one.
            and self._is_negated == other._is_negated
        )

    def negate(self) -> None:
        """
        Negates the current Q object. (mutation)
        """
        self._is_negated = not self._is_negated

    def _resolve_nested_filter(
        self, resolve_context: ResolveContext, key: str, value: Any, table: Table
    ) -> QueryModifier:
        """Resolve a ``related__rest`` filter on the related model, adding the needed joins."""
        related_field_name, __, forwarded_fields = key.partition("__")
        related_field = cast(
            RelationalField, resolve_context.model._meta.fields_map[related_field_name]
        )
        required_joins = get_joins_for_related_field(table, related_field, related_field_name)
        q = Q(**{forwarded_fields: value})
        modifier = q.resolve(
            ResolveContext(
                model=related_field.related_model,
                # The remainder of the key is resolved against the last joined table.
                table=required_joins[-1][0],
                annotations=resolve_context.annotations,
                custom_filters=resolve_context.custom_filters,
            )
        )
        return QueryModifier(joins=required_joins) & modifier

    def _resolve_custom_kwarg(
        self, resolve_context: ResolveContext, key: str, value: Any, table: Table
    ) -> QueryModifier:
        """Resolve a filter on an annotation; aggregates go to HAVING, others to WHERE."""
        having_info = resolve_context.custom_filters[key]
        annotation = resolve_context.annotations[having_info["field"]]
        if isinstance(annotation, Term):
            annotation_info = ResolveResult(term=annotation)
        else:
            annotation_info = annotation.resolve(resolve_context)
        # Named ``op`` so the module-level ``operator`` import is not shadowed.
        op = having_info["operator"]
        overridden_operator = (
            resolve_context.model._meta.db.executor_class.get_overridden_filter_func(
                filter_func=op
            )
        )
        if overridden_operator:
            op = overridden_operator
        if annotation_info.term.is_aggregate:
            modifier = QueryModifier(having_criterion=op(annotation_info.term, value))
        else:
            modifier = QueryModifier(where_criterion=op(annotation_info.term, value))
        return modifier

    def _process_filter_kwarg(
        self, model: type[Model], key: str, value: Any, table: Table
    ) -> tuple[Criterion, tuple[Table, Criterion] | None]:
        """Build the criterion for a model filter, plus the join it needs (if any)."""
        join = None
        # ``field=None`` is treated as ``field__isnull=True`` when that filter exists.
        if value is None and f"{key}__isnull" in model._meta.filters:
            filter_info = model._meta.get_filter(f"{key}__isnull")
            value = True
        else:
            filter_info = model._meta.get_filter(key)
        if "table" in filter_info:
            # join the table
            join = (
                filter_info["table"],
                table[model._meta.db_pk_column]
                == filter_info["table"][filter_info["backward_key"]],
            )
            if "value_encoder" in filter_info:
                value = filter_info["value_encoder"](value, model)
            table = filter_info["table"]
        elif not isinstance(value, Term):
            field_object = model._meta.fields_map[filter_info["field"]]
            value = (
                filter_info["value_encoder"](value, model, field_object)
                if "value_encoder" in filter_info
                else field_object.to_db_value(value, model)
            )
        op = filter_info["operator"]
        criterion = op(table[filter_info.get("source_field", filter_info["field"])], value)
        return criterion, join

    def _resolve_regular_kwarg(
        self, resolve_context: ResolveContext, key: str, value: Any, table: Table
    ) -> QueryModifier:
        """Resolve a non-custom filter, delegating related-field keys to the nested resolver."""
        if (
            key not in resolve_context.model._meta.filters
            and key.split("__")[0] in resolve_context.model._meta.fetch_fields
        ):
            modifier = self._resolve_nested_filter(resolve_context, key, value, table)
        else:
            criterion, join = self._process_filter_kwarg(resolve_context.model, key, value, table)
            joins = [join] if join else []
            modifier = QueryModifier(where_criterion=criterion, joins=joins)
        return modifier

    def _get_actual_filter_params(
        self, resolve_context: ResolveContext, key: str, value: Any
    ) -> tuple[str, Any]:
        """
        Map a user-facing filter key/value to the key/value actually filtered on.

        FK/O2O keys are rewritten to their source field; model instances are replaced
        by their ``pk``; ``Expression`` values are resolved to terms.

        :raises FieldError: if ``key`` is not a known field, relation, filter or custom filter.
        """
        filter_key = key
        if (
            key in resolve_context.model._meta.fk_fields
            or key in resolve_context.model._meta.o2o_fields
        ):
            field_object = resolve_context.model._meta.fields_map[key]
            filter_key = cast(str, field_object.source_field)
            filter_value = getattr(value, "pk", value)
        elif key in resolve_context.model._meta.m2m_fields:
            filter_value = getattr(value, "pk", value)
        elif (
            key.split("__")[0] in resolve_context.model._meta.fetch_fields
            or key in resolve_context.custom_filters
            or key in resolve_context.model._meta.filters
        ):
            filter_value = value
        else:
            allowed = sorted(
                resolve_context.model._meta.fields
                | resolve_context.model._meta.fetch_fields
                | set(resolve_context.custom_filters)
            )
            raise FieldError(f"Unknown filter param '{key}'. Allowed base values are {allowed}")
        if isinstance(filter_value, Expression):
            filter_value = filter_value.resolve(resolve_context).term
        return filter_key, filter_value

    def _resolve_kwargs(self, resolve_context: ResolveContext) -> QueryModifier:
        """Combine the modifiers of all keyword filters using this node's join type."""
        modifier = QueryModifier()
        for raw_key, raw_value in self.filters.items():
            key, value = self._get_actual_filter_params(resolve_context, raw_key, raw_value)
            if key in resolve_context.custom_filters:
                filter_modifier = self._resolve_custom_kwarg(
                    resolve_context, key, value, resolve_context.table
                )
            else:
                filter_modifier = self._resolve_regular_kwarg(
                    resolve_context, key, value, resolve_context.table
                )
            if self.join_type == self.AND:
                modifier &= filter_modifier
            else:
                modifier |= filter_modifier
        if self._is_negated:
            modifier = ~modifier
        return modifier

    def _resolve_children(self, resolve_context: ResolveContext) -> QueryModifier:
        """Combine the modifiers of all child nodes using this node's join type."""
        modifier = QueryModifier()
        for node in self.children:
            node_modifier = node.resolve(resolve_context)
            if self.join_type == self.AND:
                modifier &= node_modifier
            else:
                modifier |= node_modifier
        if self._is_negated:
            modifier = ~modifier
        return modifier

    def resolve(
        self,
        resolve_context: ResolveContext,
    ) -> QueryModifier:
        """
        Resolves the logical Q chain into the parts of a SQL statement.

        :param resolve_context: The model, table (to allow self-referential joins),
            annotations and custom filters this Q expression is resolved against.
        :return: A ``QueryModifier`` with the WHERE/HAVING criteria and joins.
        """
        if self.filters:
            return self._resolve_kwargs(resolve_context)
        return self._resolve_children(resolve_context)
class Function(Expression):
    """
    Function/Aggregate base.

    :param field: Field name, or an expression to apply the function to.
    :param default_values: Extra parameters to the function. Values that are themselves
        ``Expression`` instances (``F``, ``Value``, ``Case``, another ``Function``, ...) are
        resolved to SQL terms first; any other value is passed to the function unchanged.

    .. attribute:: database_func
        :annotation: pypika_tortoise.terms.Function

        The pypika function this represents.

    .. attribute:: populate_field_object
        :annotation: bool = False

        Enable populate_field_object where we want to try and preserve the field type.
    """

    __slots__ = ("field", "field_object", "default_values")

    database_func: type[PypikaFunction] = PypikaFunction
    # Enable populate_field_object where we want to try and preserve the field type.
    populate_field_object = False

    def __init__(
        self, field: str | F | CombinedExpression | Function, *default_values: Any
    ) -> None:
        self.field = field
        self.field_object: Field | None = None
        self.default_values = default_values

    def _get_function_field(self, field: Term | str, *default_values) -> PypikaFunction:
        """Instantiate ``database_func`` with the resolved field and extra arguments."""
        return self.database_func(field, *default_values)  # type:ignore[arg-type]

    def _resolve_nested_field(self, resolve_context: ResolveContext, field: str) -> ResolveResult:
        """Resolve a (possibly ``related__field``) field name to a term and its joins."""
        term, joins, output_field = resolve_nested_field(
            resolve_context.model, resolve_context.table, field
        )
        if self.populate_field_object:
            self.field_object = output_field
        return ResolveResult(term=term, joins=joins, output_field=output_field)

    def _resolve_default_values(self, resolve_context: ResolveContext) -> Iterator[Any]:
        """Yield the extra arguments, resolving any ``Expression`` to its SQL term."""
        for default_value in self.default_values:
            # Any Expression (not only Function) must be resolved: pypika would otherwise
            # wrap the Python object itself as a constant and render invalid SQL.
            if isinstance(default_value, Expression):
                yield default_value.resolve(resolve_context).term
            else:
                yield default_value

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Used to resolve the Function statement for SQL generation.

        :param resolve_context: The model, table (to allow self-referential joins),
            annotations and custom filters the function is resolved against.
        :return: A ``ResolveResult`` whose term is the database function applied to the
            resolved field, carrying the field's joins and output field.
        """
        default_values = self._resolve_default_values(resolve_context)
        function_arg = (
            self._resolve_nested_field(resolve_context, self.field)
            if isinstance(self.field, str)
            else self.field.resolve(resolve_context)
        )
        term = self._get_function_field(function_arg.term, *default_values)
        res = ResolveResult(
            term=term,
            joins=function_arg.joins,
            output_field=function_arg.output_field,  # type:ignore[call-overload]
        )
        if self.populate_field_object and (
            res_output_field := res.output_field  # type:ignore[call-overload]
        ):
            self.field_object = res_output_field
        return res
class Aggregate(Function):
    """
    Base for SQL Aggregates.

    :param field: Field name, or an expression to aggregate.
    :param default_values: Extra parameters to the function.
    :param distinct: Flag for aggregate with distinction
    :param _filter: Optional ``Q``; the aggregated value is wrapped in
        ``CASE WHEN <filter> THEN <value> ELSE NULL END`` so that only matching rows
        contribute. It applies to both field names and expression fields.
    """

    database_func: type[AggregateFunction] = DistinctOptionFunction

    def __init__(
        self,
        field: str | F | CombinedExpression,
        *default_values: Any,
        distinct: bool = False,
        _filter: Q | None = None,
    ) -> None:
        super().__init__(field, *default_values)
        self.distinct = distinct
        self.filter = _filter

    def _get_function_field(  # type:ignore[override]
        self, field: ArithmeticExpression | PypikaField | str, *default_values
    ) -> DistinctOptionFunction:
        """Instantiate the aggregate, adding ``DISTINCT`` when requested."""
        function = cast(DistinctOptionFunction, self.database_func(field, *default_values))
        if self.distinct:
            function = function.distinct()
        return function

    def _apply_filter(self, resolve_context: ResolveContext, result: ResolveResult) -> ResolveResult:
        """Wrap ``result.term`` in a CASE on the filter's WHERE criterion (in place) if set."""
        if self.filter:
            modifier = QueryModifier()
            modifier &= self.filter.resolve(resolve_context)
            result.term = PypikaCase().when(modifier.where_criterion, result.term).else_(None)
        return result

    def _resolve_nested_field(self, resolve_context: ResolveContext, field: str) -> ResolveResult:
        ret = super()._resolve_nested_field(resolve_context, field)
        return self._apply_filter(resolve_context, ret)

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Resolve the aggregate, applying ``_filter`` regardless of the field's kind.

        String fields go through ``_resolve_nested_field`` (which applies the filter) via
        the base implementation. Expression fields never reach that hook, so they are
        handled here; previously their filter was silently ignored.
        """
        if not self.filter or isinstance(self.field, str):
            return super().resolve(resolve_context)
        default_values = self._resolve_default_values(resolve_context)
        function_arg = self._apply_filter(resolve_context, self.field.resolve(resolve_context))
        term = self._get_function_field(function_arg.term, *default_values)
        if self.populate_field_object and function_arg.output_field:
            self.field_object = function_arg.output_field
        return ResolveResult(
            term=term,
            joins=function_arg.joins,
            output_field=function_arg.output_field,
        )
class _WhenThen(Term):
    """
    Not a real SQL term: a helper carrying a ``WHEN`` criterion and its ``THEN`` value
    from :meth:`When.resolve` to :meth:`Case.resolve`.

    :param when: The condition term.
    :param then: The result term used when the condition holds.
    """

    def __init__(self, when: Term, then: Term) -> None:
        # Initialise the base Term (which sets ``alias``), like Subquery and RawSQL do.
        super().__init__()
        self.when = when
        self.then = then
class When(Expression):
    """
    One ``WHEN ... THEN ...`` branch of a :class:`Case`.

    All conditions (positional ``Q`` objects and keyword filters) are AND-ed together.

    :param args: Q objects
    :param kwargs: keyword criterion like filter
    :param then: value for criterion
    :param negate: when true, every individual condition is negated (default: false)
    :raises TypeError: on resolve, if a positional argument is not a ``Q``.
    """

    def __init__(
        self,
        *args: Q,
        then: str | F | CombinedExpression | Function,
        negate: bool = False,
        **kwargs: Any,
    ) -> None:
        self.args = args
        self.then = then
        self.negate = negate
        self.kwargs = kwargs

    def _resolve_q_objects(self) -> list[Q]:
        """Collect positional Qs, then one Q per keyword filter, negating each if requested."""
        if any(not isinstance(candidate, Q) for candidate in self.args):
            raise TypeError("expected Q objects as args")
        conditions = [
            *self.args,
            *(Q(**{name: criterion}) for name, criterion in self.kwargs.items()),
        ]
        if not self.negate:
            return conditions
        return [~condition for condition in conditions]

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """Return a ``_WhenThen`` pairing the combined WHERE criterion with the THEN term."""
        combined = QueryModifier()
        for condition in self._resolve_q_objects():
            combined &= condition.resolve(resolve_context)
        then_term = (
            self.then.resolve(resolve_context).term
            if isinstance(self.then, Expression)
            else cast(Term, Term.wrap_constant(self.then))
        )
        return ResolveResult(term=_WhenThen(combined.where_criterion, then_term))
class Case(Expression):
    """
    ``CASE WHEN ... THEN ... [WHEN ...] ELSE <default> END`` expression.

    :param args: When objects, rendered in the given order.
    :param default: value for 'CASE WHEN ... THEN ... ELSE <default> END'; ``None`` gives ``ELSE NULL``.
    :raises TypeError: on resolve, if a positional argument is not a ``When``.
    """

    def __init__(
        self,
        *args: When,
        default: str | F | CombinedExpression | Function | None = None,
    ) -> None:
        self.args = args
        self.default = default

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """Build the pypika CASE term from every ``When`` branch and the default."""
        case_term = PypikaCase()
        for branch in self.args:
            if not isinstance(branch, When):
                raise TypeError("expected When objects as args")
            pair = cast(_WhenThen, branch.resolve(resolve_context).term)
            case_term = case_term.when(pair.when, pair.then)
        if isinstance(self.default, Expression):
            fallback = self.default.resolve(resolve_context).term
        else:
            fallback = Term.wrap_constant(self.default)
        return ResolveResult(term=case_term.else_(fallback))