# Source code for tortoise.expressions

from __future__ import annotations

import operator
from dataclasses import dataclass
from dataclasses import field as dataclass_field
from enum import Enum
from typing import TYPE_CHECKING, Any, Iterator, Type, cast

from pypika import Case as PypikaCase
from pypika import Field as PypikaField
from pypika import Table
from pypika.functions import AggregateFunction, DistinctOptionFunction
from pypika.terms import ArithmeticExpression, Criterion
from pypika.terms import Function as PypikaFunction
from pypika.terms import Term, ValueWrapper
from pypika.utils import format_alias_sql

from tortoise.exceptions import FieldError, OperationalError
from tortoise.fields.base import Field
from tortoise.fields.relational import RelationalField
from tortoise.filters import FilterInfoDict
from tortoise.query_utils import (
    QueryModifier,
    TableCriterionTuple,
    get_joins_for_related_field,
    resolve_nested_field,
)

if TYPE_CHECKING:  # pragma: nocoverage
    from pypika.queries import Selectable

    from tortoise.models import Model
    from tortoise.queryset import AwaitableQuery


@dataclass(frozen=True)
class ResolveContext:
    """
    Immutable bundle of everything an expression needs in order to resolve itself.

    It is handed to every ``resolve()`` call in this module.
    """

    # Model class the expression is being resolved against.
    model: Type["Model"]
    # pypika table that stands for ``model`` in the query being built.
    table: Table
    # Annotations known so far, keyed by annotation name. A value is either a ready
    # pypika ``Term`` or an object exposing ``resolve(resolve_context)``.
    annotations: dict[str, Any]
    # Filter descriptions for annotated fields, keyed by filter key; each entry is
    # read for its "field" and "operator" items when a custom filter is resolved.
    custom_filters: dict[str, FilterInfoDict]


@dataclass
class ResolveResult:
    """
    Outcome of resolving an expression: the SQL term plus what is needed to use it.
    """

    # pypika term that represents the expression in the query.
    term: Term
    # Joins the query needs for ``term`` to be valid; empty when none are needed.
    joins: list[TableCriterionTuple] = dataclass_field(default_factory=list)
    # Model field describing the type of ``term``, or None when it is not known.
    output_field: Field | None = None


class Expression:
    """
    Common base class of all expressions.

    Subclasses describe a value computed by the database and must override
    :meth:`resolve` to turn themselves into a pypika term.
    """

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Translate the expression into a :class:`ResolveResult`.

        :param resolve_context: Model, table, annotations and custom filters to resolve against.
        :raises NotImplementedError: Always; the base class has nothing to resolve.
        """
        raise NotImplementedError


class Value(Expression):
    """
    Expression holding a plain Python value that is to be used as a query term.
    """

    def __init__(self, value: Any) -> None:
        self.value = value

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Wrap the stored value in a pypika ``ValueWrapper``.

        The resolve context is not consulted; the result carries no joins and no
        output field.
        """
        wrapped = ValueWrapper(self.value)
        return ResolveResult(term=wrapped)


class Connector(Enum):
    """
    Arithmetic operators that can join two expressions.

    Each member's value is the name of the function in the standard ``operator``
    module that ``CombinedExpression.resolve`` looks up with ``getattr`` and applies
    to the two resolved terms.
    """

    add = "add"
    sub = "sub"
    mul = "mul"
    # The member is called ``div`` but maps to ``operator.truediv``.
    div = "truediv"
    pow = "pow"
    mod = "mod"


class CombinedExpression(Expression):
    """
    Two expressions joined by an arithmetic operator, e.g. ``F("a") + 1``.

    :param left: Left-hand operand.
    :param connector: Arithmetic operator to apply.
    :param right: Right-hand operand; anything that is not an :class:`Expression`
        is wrapped in :class:`Value`.
    """

    def __init__(self, left: Expression, connector: Connector, right: Any) -> None:
        self.left = left
        self.connector = connector
        self.right = right if isinstance(right, Expression) else Value(right)

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Resolve both operands and combine their terms with the connector's operator.

        :param resolve_context: Context both operands are resolved against.
        :return: Result whose term is the combined term, whose joins are the joins of
            both operands without duplicates (first occurrence kept, original order),
            and whose output field is the right operand's, falling back to the left's.
        :raises FieldError: If both operands report an output field and the two
            fields are of different types.
        """
        left = self.left.resolve(resolve_context)
        right = self.right.resolve(resolve_context)
        left_output_field, right_output_field = left.output_field, right.output_field  # type: ignore

        if (
            left_output_field
            and right_output_field
            and type(left_output_field) is not type(right_output_field)
        ):
            raise FieldError("Cannot use arithmetic expression between different field type")

        operator_func = getattr(operator, self.connector.value)
        # De-duplicate through a dict instead of a set: same hashing/equality, but the
        # joins stay in the order they were produced, so the result is deterministic
        # and a join never ends up ahead of the one that introduced its table.
        unique_joins = list(dict.fromkeys(left.joins + right.joins))
        return ResolveResult(
            term=operator_func(left.term, right.term),
            joins=unique_joins,
            output_field=right_output_field or left_output_field,
        )


class F(Expression):
    """
    Reference to a model field's value, to a field reached through a relation, or to
    an annotated column. Arithmetic on an ``F`` object is carried out by the database,
    so the value never has to be loaded into Python memory.

    :param name: The name of the field to reference.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Turn the field reference into a term.

        The name is tried, in this order, as a path through a relation, as the name
        of an annotation, and finally as a regular field of the model.

        :raises FieldError: If the name matches none of the above.
        """
        model_meta = resolve_context.model._meta
        relation_prefix = self.name.split("__")[0]

        if relation_prefix in model_meta.fetch_fields:
            # "related_field__field" or "related_field__another_rel_field__field"
            nested_term, nested_joins, nested_field = resolve_nested_field(
                resolve_context.model, resolve_context.table, self.name
            )
            return ResolveResult(term=nested_term, output_field=nested_field, joins=nested_joins)

        if self.name in resolve_context.annotations:
            # another annotation, e.g. M.annotate(f1=...).annotate(f2=F("f1")).values('field')
            annotation = resolve_context.annotations[self.name]
            if isinstance(annotation, Term):
                annotated_term = annotation
            else:
                annotated_term = annotation.resolve(resolve_context).term
            return ResolveResult(term=annotated_term, output_field=None, joins=[])

        # a regular model field, e.g. F("id")
        column: Term = PypikaField(self.name)
        field_object = None
        try:
            # Point the column at the database column name of the field.
            column.name = model_meta.fields_db_projection[self.name]  # type:ignore[attr-defined]

            field_object = model_meta.fields_map.get(self.name, None)
            if field_object:
                cast_func = field_object.get_for_dialect(
                    model_meta.db.capabilities.dialect, "function_cast"
                )
                if cast_func:
                    column = cast_func(field_object, column)
        except KeyError:
            raise FieldError(
                f"There is no non-virtual field {self.name} on Model {resolve_context.model.__name__}"
            ) from None
        return ResolveResult(term=column, output_field=field_object, joins=[])

    def _combine(self, other: Any, connector: Connector, right_hand: bool) -> CombinedExpression:
        """
        Build the combined expression of ``self`` and ``other``.

        :param other: The other operand; wrapped in :class:`Value` unless it already
            is an :class:`Expression`.
        :param connector: Arithmetic operator to apply.
        :param right_hand: True when ``self`` is the right-hand operand (reflected operator).
        """
        operand = other if isinstance(other, Expression) else Value(other)
        lhs, rhs = (operand, self) if right_hand else (self, operand)
        return CombinedExpression(lhs, connector, rhs)

    # Negation is expressed as a multiplication by -1.
    def __neg__(self) -> CombinedExpression:
        return self._combine(-1, Connector.mul, right_hand=False)

    # Operators with ``self`` on the left-hand side.
    def __add__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.add, right_hand=False)

    def __sub__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.sub, right_hand=False)

    def __mul__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.mul, right_hand=False)

    def __truediv__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.div, right_hand=False)

    def __mod__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.mod, right_hand=False)

    def __pow__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.pow, right_hand=False)

    # Reflected operators: ``self`` is on the right-hand side.
    def __radd__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.add, right_hand=True)

    def __rsub__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.sub, right_hand=True)

    def __rmul__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.mul, right_hand=True)

    def __rtruediv__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.div, right_hand=True)

    def __rmod__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.mod, right_hand=True)

    def __rpow__(self, other) -> CombinedExpression:
        return self._combine(other, Connector.pow, right_hand=True)


class Subquery(Term):
    """
    Term that renders another query so it can be embedded in an outer query.

    :param query: The query to embed.
    """

    def __init__(self, query: "AwaitableQuery") -> None:
        super().__init__()
        self.query = query

    def get_sql(self, **kwargs: Any) -> str:
        """
        Return the SQL text of the wrapped query.

        The query is asked to choose its database connection if it has not done so
        yet and to build itself; only the first element of the parameterized SQL
        result is returned.
        """
        inner = self.query
        inner._choose_db_if_not_chosen()
        inner._make_query()
        parameterized = inner.query.get_parameterized_sql(**kwargs)
        return parameterized[0]

    def as_(self, alias: str) -> "Selectable":  # type: ignore
        """
        Build the wrapped query and return its pypika query aliased as ``alias``.
        """
        inner = self.query
        inner._choose_db_if_not_chosen()
        inner._make_query()
        return inner.query.as_(alias)


class RawSQL(Term):
    """
    Term whose SQL is the given string, emitted as-is.

    :param sql: SQL fragment to emit.
    """

    def __init__(self, sql: str) -> None:
        super().__init__()
        self.sql = sql

    def get_sql(self, with_alias: bool = False, **kwargs: Any) -> str:
        """
        Return the stored SQL, formatted together with the term's alias on request.

        :param with_alias: When true, the SQL is passed through ``format_alias_sql``
            with the term's alias and the remaining keyword arguments.
        """
        if not with_alias:
            return self.sql
        return format_alias_sql(sql=self.sql, alias=self.alias, **kwargs)


class Q:
    """
    Q Expression container.
    Q Expressions are a useful tool to compose a query from many small parts.

    A Q node holds either child Q nodes or keyword filters, never both: when both are
    passed, the keyword filters are moved into an extra child node.

    :param join_type: Is the join an AND or OR join type?
    :param args: Inner ``Q`` expressions that you want to wrap.
    :param kwargs: Filter statements that this Q object should encapsulate.
    :raises OperationalError: If a positional argument is not a ``Q`` node or
        ``join_type`` is neither ``AND`` nor ``OR``.
    """

    __slots__ = (
        "children",
        "filters",
        "join_type",
        "_is_negated",
    )

    AND = "AND"
    OR = "OR"

    def __init__(self, *args: "Q", join_type: str = AND, **kwargs: Any) -> None:
        if args and kwargs:
            # Keep the node homogeneous: the filters become one more child node.
            newarg = Q(join_type=join_type, **kwargs)
            args = (newarg,) + args
            kwargs = {}
        if not all(isinstance(node, Q) for node in args):
            raise OperationalError("All ordered arguments must be Q nodes")
        #: Contains the sub-Q's that this Q is made up of
        self.children: tuple[Q, ...] = args
        #: Contains the filters applied to this Q
        self.filters: dict[str, FilterInfoDict] = kwargs
        if join_type not in {self.AND, self.OR}:
            raise OperationalError("join_type must be AND or OR")
        #: Specifies if this Q does an AND or OR on its children
        self.join_type = join_type
        self._is_negated = False

    def __and__(self, other: "Q") -> "Q":
        """
        Returns a binary AND of Q objects, use ``AND`` operator.

        :raises OperationalError: AND operation requires a Q node
        """
        if not isinstance(other, Q):
            raise OperationalError("AND operation requires a Q node")
        return Q(self, other, join_type=self.AND)

    def __or__(self, other: "Q") -> "Q":
        """
        Returns a binary OR of Q objects, use ``OR`` operator.

        :raises OperationalError: OR operation requires a Q node
        """
        if not isinstance(other, Q):
            raise OperationalError("OR operation requires a Q node")
        return Q(self, other, join_type=self.OR)

    def __invert__(self) -> "Q":
        """
        Returns a negated copy of the Q object, use ``~`` operator.

        The copy's negation flag is the opposite of this object's, so inverting
        twice yields a Q that is equal to the original.
        """
        q = Q(*self.children, join_type=self.join_type, **self.filters)
        # The copy starts out non-negated; derive its flag from ours rather than
        # unconditionally negating it, otherwise ``~~q`` would still be negated.
        q._is_negated = not self._is_negated
        return q

    def __eq__(self, other: object) -> bool:
        """
        Two Q objects are equal when children, join type, filters and negation match.
        """
        if not isinstance(other, Q):
            return False
        return (
            self.children == other.children
            and self.join_type == other.join_type
            and self.filters == other.filters
            and self._is_negated == other._is_negated
        )

    def negate(self) -> None:
        """
        Negates the current Q object. (mutation)
        """
        self._is_negated = not self._is_negated

    def _resolve_nested_filter(
        self, resolve_context: ResolveContext, key: str, value: Any, table: Table
    ) -> QueryModifier:
        """
        Resolve a filter that reaches through a relation (``relation__rest=value``).

        The part of the key after the first ``__`` is resolved as a new Q against the
        related model, using the table of the last required join; the joins needed to
        reach the related model are combined with the result.
        """
        related_field_name, __, forwarded_fields = key.partition("__")
        related_field = cast(
            RelationalField, resolve_context.model._meta.fields_map[related_field_name]
        )
        required_joins = get_joins_for_related_field(table, related_field, related_field_name)
        q = Q(**{forwarded_fields: value})
        modifier = q.resolve(
            ResolveContext(
                model=related_field.related_model,
                table=required_joins[-1][0],
                annotations=resolve_context.annotations,
                custom_filters=resolve_context.custom_filters,
            )
        )
        return QueryModifier(joins=required_joins) & modifier

    def _resolve_custom_kwarg(
        self, resolve_context: ResolveContext, key: str, value: Any, table: Table
    ) -> QueryModifier:
        """
        Resolve a filter on an annotation.

        The criterion goes into the HAVING part when the annotation's term is an
        aggregate and into the WHERE part otherwise.
        """
        having_info = resolve_context.custom_filters[key]
        annotation = resolve_context.annotations[having_info["field"]]
        if isinstance(annotation, Term):
            annotation_info = ResolveResult(term=annotation)
        else:
            annotation_info = annotation.resolve(resolve_context)

        # Named ``filter_func`` so the imported ``operator`` module is not shadowed.
        filter_func = having_info["operator"]
        overridden_operator = (
            resolve_context.model._meta.db.executor_class.get_overridden_filter_func(
                filter_func=filter_func
            )
        )
        if overridden_operator:
            filter_func = overridden_operator

        if annotation_info.term.is_aggregate:
            modifier = QueryModifier(having_criterion=filter_func(annotation_info.term, value))
        else:
            modifier = QueryModifier(where_criterion=filter_func(annotation_info.term, value))
        return modifier

    def _process_filter_kwarg(
        self, model: "Type[Model]", key: str, value: Any, table: Table
    ) -> tuple[Criterion, tuple[Table, Criterion] | None]:
        """
        Build the criterion for a filter the model knows about.

        :return: The criterion and, when the filter lives on another table, the join
            (table, join condition) needed to reach it; otherwise ``None`` as join.
        """
        join = None

        if value is None and f"{key}__isnull" in model._meta.filters:
            # ``field=None`` is answered by the field's ``__isnull`` filter.
            param = model._meta.get_filter(f"{key}__isnull")
            value = True
        else:
            param = model._meta.get_filter(key)

        pk_db_field = model._meta.db_pk_column
        if param.get("table"):
            join = (
                param["table"],
                table[pk_db_field] == param["table"][param["backward_key"]],
            )
            if param.get("value_encoder"):
                value = param["value_encoder"](value, model)
            op = param["operator"]
            criterion = op(param["table"][param["field"]], value)
        else:
            if isinstance(value, Term):
                # Already a pypika term: used as-is, no encoding.
                encoded_value = value
            else:
                field_object = model._meta.fields_map[param["field"]]
                encoded_value = (
                    param["value_encoder"](value, model, field_object)
                    if param.get("value_encoder")
                    else field_object.to_db_value(value, model)
                )
            op = param["operator"]
            criterion = op(table[param["source_field"]], encoded_value)
        return criterion, join

    def _resolve_regular_kwarg(
        self, resolve_context: ResolveContext, key: str, value: Any, table: Table
    ) -> QueryModifier:
        """
        Resolve a filter that is not a custom (annotation) filter.

        Keys that are not registered filters of the model but start with the name of
        a relation are resolved through the relation; everything else is resolved as
        a filter of the model itself.
        """
        if (
            key not in resolve_context.model._meta.filters
            and key.split("__")[0] in resolve_context.model._meta.fetch_fields
        ):
            modifier = self._resolve_nested_filter(resolve_context, key, value, table)
        else:
            criterion, join = self._process_filter_kwarg(resolve_context.model, key, value, table)
            joins = [join] if join else []
            modifier = QueryModifier(where_criterion=criterion, joins=joins)
        return modifier

    def _get_actual_filter_params(
        self, resolve_context: ResolveContext, key: str, value: Table | FilterInfoDict
    ) -> tuple[str, Any]:
        """
        Normalise one ``key=value`` filter before it is resolved.

        Foreign-key and one-to-one keys are replaced by the field's source field and
        the value by its ``pk`` attribute when it has one; many-to-many values are
        replaced by their ``pk`` attribute when they have one. Expression values are
        resolved to their term.

        :raises FieldError: If the key is no field, relation, registered filter or
            custom filter.
        """
        filter_key = key
        if (
            key in resolve_context.model._meta.fk_fields
            or key in resolve_context.model._meta.o2o_fields
        ):
            field_object = resolve_context.model._meta.fields_map[key]
            filter_key = cast(str, field_object.source_field)
            filter_value = getattr(value, "pk", value)
        elif key in resolve_context.model._meta.m2m_fields:
            filter_value = getattr(value, "pk", value)
        elif (
            key.split("__")[0] in resolve_context.model._meta.fetch_fields
            or key in resolve_context.custom_filters
            or key in resolve_context.model._meta.filters
        ):
            filter_value = value
        else:
            allowed = sorted(
                resolve_context.model._meta.fields
                | resolve_context.model._meta.fetch_fields
                | set(resolve_context.custom_filters)
            )
            raise FieldError(f"Unknown filter param '{key}'. Allowed base values are {allowed}")

        if isinstance(filter_value, Expression):
            filter_value = filter_value.resolve(resolve_context).term

        return filter_key, filter_value

    def _resolve_kwargs(self, resolve_context: ResolveContext) -> QueryModifier:
        """
        Resolve the keyword filters and combine them according to ``join_type``.
        """
        modifier = QueryModifier()
        for raw_key, raw_value in self.filters.items():
            key, value = self._get_actual_filter_params(resolve_context, raw_key, raw_value)
            if key in resolve_context.custom_filters:
                filter_modifier = self._resolve_custom_kwarg(
                    resolve_context, key, value, resolve_context.table
                )
            else:
                filter_modifier = self._resolve_regular_kwarg(
                    resolve_context, key, value, resolve_context.table
                )

            if self.join_type == self.AND:
                modifier &= filter_modifier
            else:
                modifier |= filter_modifier
        if self._is_negated:
            modifier = ~modifier
        return modifier

    def _resolve_children(self, resolve_context: ResolveContext) -> QueryModifier:
        """
        Resolve the child Q nodes and combine them according to ``join_type``.
        """
        modifier = QueryModifier()
        for node in self.children:
            node_modifier = node.resolve(resolve_context)
            if self.join_type == self.AND:
                modifier &= node_modifier
            else:
                modifier |= node_modifier

        if self._is_negated:
            modifier = ~modifier
        return modifier

    def resolve(
        self,
        resolve_context: ResolveContext,
    ) -> QueryModifier:
        """
        Resolves the logical Q chain into the parts of a SQL statement.

        :param resolve_context: The model this Q Expression should be resolved on,
            the ``pypika.Table`` that stands for it in the query, and the known
            annotations and custom filters.
        :return: The query modifier built from the filters if there are any,
            otherwise the one built from the children.
        """
        if self.filters:
            return self._resolve_kwargs(resolve_context)
        return self._resolve_children(resolve_context)
class Function(Expression):
    """
    Base class for SQL functions and aggregates.

    :param field: Field name, or an expression, the function is applied to.
    :param default_values: Extra parameters to the function.

    .. attribute:: database_func
        :annotation: pypika.terms.Function

        The pypika function this represents.

    .. attribute:: populate_field_object
        :annotation: bool = False

        Enable populate_field_object where we want to try and preserve the field type.
    """

    __slots__ = ("field", "field_object", "default_values")

    database_func: Type[PypikaFunction] = PypikaFunction
    # When enabled, resolving stores the resolved output field on ``field_object``.
    populate_field_object = False

    def __init__(
        self, field: str | F | CombinedExpression | Function, *default_values: Any
    ) -> None:
        self.field = field
        self.field_object: Field | None = None
        self.default_values = default_values

    def _get_function_field(self, field: Term | str, *default_values) -> PypikaFunction:
        """Instantiate ``database_func`` with the field followed by the extra parameters."""
        return self.database_func(field, *default_values)  # type:ignore[arg-type]

    def _resolve_nested_field(self, resolve_context: ResolveContext, field: str) -> ResolveResult:
        """
        Resolve a field given by name, possibly a path through relations.

        Remembers the output field on ``field_object`` when ``populate_field_object``
        is enabled.
        """
        nested_term, nested_joins, nested_field = resolve_nested_field(
            resolve_context.model, resolve_context.table, field
        )
        if self.populate_field_object:
            self.field_object = nested_field
        return ResolveResult(term=nested_term, joins=nested_joins, output_field=nested_field)

    def _resolve_default_values(self, resolve_context: ResolveContext) -> Iterator[Any]:
        """
        Yield the extra parameters, with nested ``Function`` objects replaced by their term.
        """
        for extra in self.default_values:
            yield extra.resolve(resolve_context).term if isinstance(extra, Function) else extra

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Used to resolve the Function statement for SQL generation.

        :param resolve_context: Context the function argument and the extra
            parameters are resolved against.
        :return: Result holding the function term together with the joins and the
            output field of the function argument.
        """
        # A generator: the extra parameters are only resolved when it is unpacked below.
        extra_args = self._resolve_default_values(resolve_context)

        if isinstance(self.field, str):
            resolved_arg = self._resolve_nested_field(resolve_context, self.field)
        else:
            resolved_arg = self.field.resolve(resolve_context)

        result = ResolveResult(
            term=self._get_function_field(resolved_arg.term, *extra_args),
            joins=resolved_arg.joins,
            output_field=resolved_arg.output_field,  # type:ignore[call-overload]
        )
        if self.populate_field_object and result.output_field:
            self.field_object = result.output_field
        return result
class Aggregate(Function):
    """
    Base for SQL Aggregates.

    :param field: Field name, or an expression, the aggregate is applied to.
    :param default_values: Extra parameters to the function.
    :param distinct: Flag for aggregate with distinction
    :param _filter: Optional ``Q`` restricting the rows fed to the aggregate; it is
        applied when the field is given by name.
    """

    database_func: Type[AggregateFunction] = DistinctOptionFunction

    def __init__(
        self,
        field: str | F | CombinedExpression,
        *default_values: Any,
        distinct: bool = False,
        _filter: Q | None = None,
    ) -> None:
        super().__init__(field, *default_values)
        self.distinct = distinct
        self.filter = _filter

    def _get_function_field(  # type:ignore[override]
        self, field: ArithmeticExpression | PypikaField | str, *default_values
    ) -> DistinctOptionFunction:
        """Instantiate the aggregate, switched to its DISTINCT form when requested."""
        aggregate = cast(DistinctOptionFunction, self.database_func(field, *default_values))
        return aggregate.distinct() if self.distinct else aggregate

    def _resolve_nested_field(self, resolve_context: ResolveContext, field: str) -> ResolveResult:
        """
        Resolve the named field and, when a filter is set, wrap its term in
        ``CASE WHEN <filter> THEN <term> ELSE NULL``.
        """
        resolved = super()._resolve_nested_field(resolve_context, field)
        if not self.filter:
            return resolved

        condition = QueryModifier()
        condition &= self.filter.resolve(resolve_context)
        resolved.term = PypikaCase().when(condition.where_criterion, resolved.term).else_(None)
        return resolved
class _WhenThen(Term):
    """This is not a real term, but a helper to store the when and then terms."""

    def __init__(self, when: Term, then: Term) -> None:
        # Initialise the Term base like the other Term subclasses in this module do,
        # so attributes set up by the base initialiser exist on the helper as well.
        super().__init__()
        self.when = when
        self.then = then
class When(Expression):
    """
    When expression.

    :param args: Q objects
    :param kwargs: keyword criterion like filter
    :param then: value for criterion
    :param negate: false (default)
    """

    def __init__(
        self,
        *args: Q,
        then: str | F | CombinedExpression | Function,
        negate: bool = False,
        **kwargs: Any,
    ) -> None:
        self.args = args
        self.then = then
        self.negate = negate
        self.kwargs = kwargs

    def _resolve_q_objects(self) -> list[Q]:
        """
        Collect the conditions as Q objects: the positional ones first, then one Q
        per keyword criterion; every one is inverted when ``negate`` is set.

        :raises TypeError: If a positional argument is not a Q object.
        """
        for candidate in self.args:
            if not isinstance(candidate, Q):
                raise TypeError("expected Q objects as args")

        conditions = list(self.args)
        conditions.extend(Q(**{key: value}) for key, value in self.kwargs.items())
        if self.negate:
            conditions = [~condition for condition in conditions]
        return conditions

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Resolve the conditions (ANDed together) and the ``then`` value.

        :return: Result whose term is a ``_WhenThen`` helper holding the WHERE
            criterion of the combined conditions and the resolved ``then`` term.
        """
        conditions = self._resolve_q_objects()

        criteria = QueryModifier()
        for condition in conditions:
            criteria &= condition.resolve(resolve_context)

        if isinstance(self.then, Expression):
            outcome = self.then.resolve(resolve_context).term
        else:
            outcome = cast(Term, Term.wrap_constant(self.then))

        return ResolveResult(term=_WhenThen(criteria.where_criterion, outcome))
class Case(Expression):
    """
    Case expression.

    :param args: When objects
    :param default: value for 'CASE WHEN ... THEN ... ELSE <default> END'
    """

    def __init__(
        self,
        *args: When,
        default: str | F | CombinedExpression | Function | None = None,
    ) -> None:
        self.args = args
        self.default = default

    def resolve(self, resolve_context: ResolveContext) -> ResolveResult:
        """
        Build the pypika CASE term from the When branches and the default.

        :raises TypeError: If a positional argument is not a When object.
        """
        builder = PypikaCase()
        for branch in self.args:
            # Checked branch by branch, right before the branch is resolved.
            if not isinstance(branch, When):
                raise TypeError("expected When objects as args")
            pair = cast(_WhenThen, branch.resolve(resolve_context).term)
            builder = builder.when(pair.when, pair.then)

        if isinstance(self.default, Expression):
            fallback = self.default.resolve(resolve_context).term
        else:
            fallback = Term.wrap_constant(self.default)
        builder = builder.else_(fallback)

        return ResolveResult(term=builder)