Source code for tortoise.contrib.postgres.search

from __future__ import annotations

from collections.abc import Iterable, Sequence
from itertools import chain

from pypika_tortoise.enums import Comparator
from pypika_tortoise.terms import BasicCriterion, Term, ValueWrapper
from pypika_tortoise.terms import Function as PypikaFunction

from tortoise.contrib.postgres.fields import TSVectorField
from tortoise.contrib.postgres.functions import PlainToTsQuery, ToTsVector
from tortoise.expressions import Expression, F, ResolveContext, ResolveResult, Value
from tortoise.fields import FloatField, TextField
from tortoise.query_utils import TableCriterionTuple

ScalarValue = str | int | float | bool
VectorInput = Expression | Term | str
QueryInput = Expression | Term | str
ConfigInput = Expression | Term | str
WeightInput = Expression | Term | str
RankWeightInput = Expression | Term | Sequence[float] | Sequence[int] | str
NormalizationInput = Expression | Term | int
HeadlineExpressionInput = Expression | Term | str
HeadlineOptionValue = str | int | bool


class Comp(Comparator):
    search = " @@ "


[docs] class SearchCriterion(BasicCriterion): def __init__(self, field: Term, expr: Term | str, vectorize: bool = True) -> None: vector = ToTsVector(field) if vectorize else field query = expr if isinstance(expr, Term) else PlainToTsQuery(ValueWrapper(expr)) super().__init__(Comp.search, vector, query)
class _TsInfixOperator(Term): def __init__(self, left: Term, operator: str, right: Term) -> None: super().__init__() self.left = left self.operator = operator self.right = right @property def is_aggregate(self) -> bool | None: # type:ignore[override] return self.left.is_aggregate or self.right.is_aggregate def get_sql(self, ctx) -> str: left_sql = self.left.get_sql(ctx) right_sql = self.right.get_sql(ctx) sql = f"({left_sql}{self.operator}{right_sql})" if ctx.with_alias and self.alias: # pragma: nocoverage return f'{sql} "{self.alias}"' return sql class _TsQueryInvert(Term): def __init__(self, term: Term) -> None: super().__init__() self.term = term @property def is_aggregate(self) -> bool | None: # type:ignore[override] return self.term.is_aggregate def get_sql(self, ctx) -> str: sql = f"!!({self.term.get_sql(ctx)})" if ctx.with_alias and self.alias: # pragma: nocoverage return f'{sql} "{self.alias}"' return sql def _merge_joins(*joins: Iterable[TableCriterionTuple]) -> list[TableCriterionTuple]: return list(set(chain.from_iterable(joins))) def _resolve_expression( value: Expression | Term | ScalarValue | Sequence[float] | Sequence[int] | str | None, resolve_context: ResolveContext, *, treat_str_as_field: bool, ) -> ResolveResult: if isinstance(value, Expression): return value.resolve(resolve_context) if isinstance(value, Term): return ResolveResult(term=value) if isinstance(value, str) and treat_str_as_field: return F(value).resolve(resolve_context) return Value(value).resolve(resolve_context) class SearchVectorCombinable(Expression): def _combine(self, other: SearchVectorCombinable, reversed: bool) -> CombinedSearchVector: if not isinstance(other, SearchVectorCombinable): raise TypeError( "SearchVector can only be combined with other SearchVector instances, " f"got {other.__class__.__name__}." ) if reversed: return CombinedSearchVector(other, self) return CombinedSearchVector(self, other) def __add__(self, other: SearchVectorCombinable) -> CombinedSearchVector: return self._combine(other, False) def __radd__(self, other: SearchVectorCombinable) -> CombinedSearchVector: return self._combine(other, True)
[docs] class SearchVector(SearchVectorCombinable, Expression): def __init__( self, *expressions: VectorInput, config: ConfigInput | None = None, weight: WeightInput | None = None, ): if not expressions: raise ValueError("SearchVector requires at least one expression.") self.expressions = expressions self.config = config self.weight = weight def resolve(self, resolve_context: ResolveContext) -> ResolveResult: resolved = [ _resolve_expression(expr, resolve_context, treat_str_as_field=True) for expr in self.expressions ] terms = [item.term for item in resolved] joins = _merge_joins(*(item.joins for item in resolved)) combined = terms[0] for term in terms[1:]: combined = _TsInfixOperator(combined, " || ", ValueWrapper(" ")) combined = _TsInfixOperator(combined, " || ", term) args = [combined] if self.config is not None: config_resolved = _resolve_expression( self.config, resolve_context, treat_str_as_field=False ) args = [config_resolved.term, combined] joins = _merge_joins(joins, config_resolved.joins) vector_term = PypikaFunction("TO_TSVECTOR", *args) if self.weight is not None: weight_resolved = _resolve_expression( self.weight, resolve_context, treat_str_as_field=False ) vector_term = PypikaFunction( "SETWEIGHT", vector_term, weight_resolved.term, ) joins = _merge_joins(joins, weight_resolved.joins) return ResolveResult(term=vector_term, joins=joins, output_field=TSVectorField())
class CombinedSearchVector(SearchVectorCombinable, Expression): def __init__(self, left: SearchVectorCombinable, right: SearchVectorCombinable) -> None: self.left = left self.right = right def resolve(self, resolve_context: ResolveContext) -> ResolveResult: left = self.left.resolve(resolve_context) right = self.right.resolve(resolve_context) term = _TsInfixOperator(left.term, " || ", right.term) return ResolveResult( term=term, joins=_merge_joins(left.joins, right.joins), output_field=TSVectorField(), ) class SearchQueryCombinable(Expression): def _combine( self, other: SearchQueryCombinable, operator: str, reversed: bool ) -> CombinedSearchQuery: if not isinstance(other, SearchQueryCombinable): raise TypeError( "SearchQuery can only be combined with other SearchQuery instances, " f"got {other.__class__.__name__}." ) if reversed: return CombinedSearchQuery(other, operator, self) return CombinedSearchQuery(self, operator, other) def __or__(self, other: SearchQueryCombinable) -> CombinedSearchQuery: return self._combine(other, " || ", False) def __ror__(self, other: SearchQueryCombinable) -> CombinedSearchQuery: return self._combine(other, " || ", True) def __and__(self, other: SearchQueryCombinable) -> CombinedSearchQuery: return self._combine(other, " && ", False) def __rand__(self, other: SearchQueryCombinable) -> CombinedSearchQuery: return self._combine(other, " && ", True)
[docs] class SearchQuery(SearchQueryCombinable, Expression): SEARCH_TYPES = { "plain": "PLAINTO_TSQUERY", "phrase": "PHRASETO_TSQUERY", "raw": "TO_TSQUERY", "websearch": "WEBSEARCH_TO_TSQUERY", } def __init__( self, value: QueryInput, config: ConfigInput | None = None, search_type: str = "plain", invert: bool = False, ) -> None: if isinstance(value, LexemeCombinable): search_type = "raw" function = self.SEARCH_TYPES.get(search_type) if function is None: raise ValueError(f"Unknown search_type argument '{search_type}'.") self.function = function self.value = value self.config = config self.invert = invert def resolve(self, resolve_context: ResolveContext) -> ResolveResult: value_result = _resolve_expression(self.value, resolve_context, treat_str_as_field=False) joins = value_result.joins args = [value_result.term] if self.config is not None: config_result = _resolve_expression( self.config, resolve_context, treat_str_as_field=False ) args = [config_result.term, value_result.term] joins = _merge_joins(joins, config_result.joins) term: Term = PypikaFunction(self.function, *args) if self.invert: term = _TsQueryInvert(term) return ResolveResult(term=term, joins=joins) def __invert__(self) -> SearchQuery: return SearchQuery( self.value, config=self.config, search_type="raw" if isinstance(self.value, LexemeCombinable) else self._search_type, invert=not self.invert, ) @property def _search_type(self) -> str: for key, value in self.SEARCH_TYPES.items(): if value == self.function: return key return "plain"
class CombinedSearchQuery(SearchQueryCombinable, Expression): def __init__( self, left: SearchQueryCombinable, operator: str, right: SearchQueryCombinable ) -> None: self.left = left self.right = right self.operator = operator def resolve(self, resolve_context: ResolveContext) -> ResolveResult: left = self.left.resolve(resolve_context) right = self.right.resolve(resolve_context) term = _TsInfixOperator(left.term, self.operator, right.term) return ResolveResult(term=term, joins=_merge_joins(left.joins, right.joins))
[docs] class SearchRank(Expression): def __init__( self, vector: VectorInput, query: QueryInput, weights: RankWeightInput | None = None, normalization: NormalizationInput | None = None, cover_density: bool = False, ) -> None: self.vector = vector self.query = query self.weights = weights self.normalization = normalization self.cover_density = cover_density def resolve(self, resolve_context: ResolveContext) -> ResolveResult: vector_expr = ( self.vector if isinstance(self.vector, (Expression, Term)) else SearchVector(self.vector) ) query_expr = ( self.query if isinstance(self.query, (Expression, Term)) else SearchQuery(self.query) ) vector_result = _resolve_expression(vector_expr, resolve_context, treat_str_as_field=False) query_result = _resolve_expression(query_expr, resolve_context, treat_str_as_field=False) args = [vector_result.term, query_result.term] joins = _merge_joins(vector_result.joins, query_result.joins) if self.weights is not None: weights_result = _resolve_expression( self.weights, resolve_context, treat_str_as_field=False ) args = [weights_result.term, *args] joins = _merge_joins(joins, weights_result.joins) if self.normalization is not None: normalization_result = _resolve_expression( self.normalization, resolve_context, treat_str_as_field=False ) args.append(normalization_result.term) joins = _merge_joins(joins, normalization_result.joins) function = "TS_RANK_CD" if self.cover_density else "TS_RANK" term = PypikaFunction(function, *args) return ResolveResult(term=term, joins=joins, output_field=FloatField())
def _format_headline_option_value(value: HeadlineOptionValue) -> str: if isinstance(value, bool): return "true" if value else "false" if isinstance(value, str): return "'" + value.replace("'", "''") + "'" return str(value)
[docs] class SearchHeadline(Expression): def __init__( self, expression: HeadlineExpressionInput, query: QueryInput, config: ConfigInput | None = None, start_sel: str | None = None, stop_sel: str | None = None, max_words: int | None = None, min_words: int | None = None, short_word: int | None = None, highlight_all: bool | None = None, max_fragments: int | None = None, fragment_delimiter: str | None = None, ) -> None: self.expression = expression self.query = query self.config = config self.options = { "StartSel": start_sel, "StopSel": stop_sel, "MaxWords": max_words, "MinWords": min_words, "ShortWord": short_word, "HighlightAll": highlight_all, "MaxFragments": max_fragments, "FragmentDelimiter": fragment_delimiter, } def resolve(self, resolve_context: ResolveContext) -> ResolveResult: expression_result = _resolve_expression( self.expression, resolve_context, treat_str_as_field=True ) query_expr = ( self.query if isinstance(self.query, (Expression, Term)) else SearchQuery(self.query) ) query_result = _resolve_expression(query_expr, resolve_context, treat_str_as_field=False) args = [expression_result.term, query_result.term] joins = _merge_joins(expression_result.joins, query_result.joins) if self.config is not None: config_result = _resolve_expression( self.config, resolve_context, treat_str_as_field=False ) args = [config_result.term, *args] joins = _merge_joins(joins, config_result.joins) options = {key: value for key, value in self.options.items() if value is not None} if options: options_sql = ", ".join( f"{key}={_format_headline_option_value(value)}" for key, value in options.items() ) args.append(ValueWrapper(options_sql)) term = PypikaFunction("TS_HEADLINE", *args) return ResolveResult(term=term, joins=joins, output_field=TextField())
class LexemeCombinable(Expression): def _combine(self, other: LexemeCombinable, operator: str, reversed: bool) -> CombinedLexeme: if not isinstance(other, LexemeCombinable): raise TypeError( "A Lexeme can only be combined with another Lexeme, " f"got {other.__class__.__name__}." ) if reversed: return CombinedLexeme(other, operator, self) return CombinedLexeme(self, operator, other) def __or__(self, other: LexemeCombinable) -> CombinedLexeme: return self._combine(other, " | ", False) def __ror__(self, other: LexemeCombinable) -> CombinedLexeme: return self._combine(other, " | ", True) def __and__(self, other: LexemeCombinable) -> CombinedLexeme: return self._combine(other, " & ", False) def __rand__(self, other: LexemeCombinable) -> CombinedLexeme: return self._combine(other, " & ", True) def _as_tsquery(self) -> str: raise NotImplementedError def __invert__(self) -> LexemeCombinable: raise NotImplementedError
[docs] class Lexeme(LexemeCombinable, Expression): def __init__( self, value: str, invert: bool = False, prefix: bool = False, weight: str | None = None, ) -> None: if value == "": raise ValueError("Lexeme value cannot be empty.") if not isinstance(value, str): raise TypeError(f"Lexeme value must be a string, got {value.__class__.__name__}.") if weight is not None and weight.lower() not in {"a", "b", "c", "d"}: raise ValueError(f"Weight must be one of 'A', 'B', 'C', and 'D', got {weight!r}.") self.value = value self.invert = invert self.prefix = prefix self.weight = weight def _as_tsquery(self) -> str: token = "'" + self.value.replace("'", "''") + "'" label = "" if self.prefix: label += "*" if self.weight: label += self.weight if label: token = f"{token}:{label}" if self.invert: token = f"!{token}" return token def resolve(self, resolve_context: ResolveContext) -> ResolveResult: return ResolveResult(term=ValueWrapper(self._as_tsquery())) def __invert__(self) -> Lexeme: return Lexeme( self.value, invert=not self.invert, prefix=self.prefix, weight=self.weight, )
class CombinedLexeme(LexemeCombinable, Expression): def __init__(self, left: LexemeCombinable, operator: str, right: LexemeCombinable) -> None: self.left = left self.right = right self.operator = operator def _as_tsquery(self) -> str: return f"({self.left._as_tsquery()}{self.operator}{self.right._as_tsquery()})" def resolve(self, resolve_context: ResolveContext) -> ResolveResult: return ResolveResult(term=ValueWrapper(self._as_tsquery())) def __invert__(self) -> CombinedLexeme: operator = " & " if self.operator == " | " else " | " return CombinedLexeme(~self.left, operator, ~self.right)