Source code for tortoise.contrib.postgres.fields

from __future__ import annotations

from collections.abc import Sequence
from typing import Any

from tortoise.exceptions import ConfigurationError
from tortoise.fields import Field


class TSVectorField(Field):
    """
    PostgreSQL ``TSVECTOR`` column, optionally generated from other columns.

    When ``source_fields`` are given and ``stored`` is true, the column is
    declared with a ``GENERATED ALWAYS AS (...) STORED`` clause that
    concatenates ``TO_TSVECTOR`` of each source column.

    :param source_fields: Name (or names) of model fields the vector is built from.
        A single string is treated as a one-element sequence.
    :param config: Optional text search configuration passed as the first
        argument of ``TO_TSVECTOR``.
    :param weights: Optional per-source-field weights applied with ``SETWEIGHT``;
        must have the same length as ``source_fields``. An empty sequence is
        treated as "no weights".
    :param stored: Whether the column is a stored generated column. Forced to
        ``False`` when no ``source_fields`` are given.
    :param kwargs: Forwarded to ``Field``. If ``generated`` is supplied it must
        agree with the effective ``stored`` value.
    :raises ConfigurationError: On inconsistent ``generated``/``stored`` values,
        on ``generated`` or ``weights`` without ``source_fields``, or when
        ``weights`` and ``source_fields`` differ in length.
    """

    SQL_TYPE = "TSVECTOR"
    allows_generated = True

    def __init__(
        self,
        source_fields: Sequence[str] | str | None = None,
        config: str | None = None,
        weights: Sequence[str] | None = None,
        stored: bool = True,
        **kwargs: Any,
    ) -> None:
        if isinstance(source_fields, str):
            source_fields = (source_fields,)
        self.source_fields = tuple(source_fields or ())

        # Check this before the generated/stored comparison: otherwise an
        # explicit generated=True without source fields is reported as a
        # "must match 'stored'" mismatch, which hides the real problem.
        if kwargs.get("generated") and not self.source_fields:
            raise ConfigurationError("TSVectorField generated columns require source_fields.")

        # Without source columns there is nothing to generate from, so the
        # default stored=True silently degrades to a plain column.
        if not self.source_fields:
            stored = False

        if "generated" in kwargs and kwargs["generated"] != stored:
            raise ConfigurationError("TSVectorField 'generated' must match 'stored' when provided.")
        generated = kwargs.pop("generated", stored)

        super().__init__(generated=generated, **kwargs)
        self.config = config
        self.weights = tuple(weights) if weights is not None else None
        self.stored = stored

        if self.weights and not self.source_fields:
            raise ConfigurationError("TSVectorField weights require source_fields.")
        if self.weights and len(self.weights) != len(self.source_fields):
            raise ConfigurationError("TSVectorField weights must match source_fields length.")

    def _quote_sql_literal(self, value: str) -> str:
        """Return ``value`` as a single-quoted SQL string literal."""
        escaped = value.replace("'", "''")
        return f"'{escaped}'"

    def _to_tsvector_sql(self, db_field: str) -> str:
        """Return the ``TO_TSVECTOR(...)`` expression for one database column."""
        # Double embedded quotes so the name stays a single quoted identifier.
        quoted_column = db_field.replace('"', '""')
        # COALESCE keeps a NULL source column from nulling the whole vector.
        field_sql = f"COALESCE(\"{quoted_column}\", '')"
        if self.config is not None:
            return f"TO_TSVECTOR({self._quote_sql_literal(self.config)},{field_sql})"
        return f"TO_TSVECTOR({field_sql})"

    def _get_generated_sql(self) -> str | None:
        """
        Build the generated-column clause for this field.

        :return: ``GENERATED ALWAYS AS (...) STORED`` or ``None`` when the
            field is not stored.
        :raises ConfigurationError: If a source field is unknown on the model
            or does not map to a database column.
        """
        if not self.stored:
            return None

        # Use truthiness, matching the validation in __init__: an empty weights
        # sequence passes validation and must not be indexed here.
        apply_weights = bool(self.weights)

        parts: list[str] = []
        for idx, field_name in enumerate(self.source_fields):
            field = self.model._meta.fields_map.get(field_name)
            if field is None:
                raise ConfigurationError(f"Unknown source field '{field_name}'.")
            if not field.has_db_field:
                raise ConfigurationError(
                    f"Source field '{field_name}' does not map to a database column."
                )
            db_field = field.source_field or field.model_field_name
            vector_sql = self._to_tsvector_sql(db_field)
            if apply_weights:
                weight = self._quote_sql_literal(self.weights[idx])
                vector_sql = f"SETWEIGHT({vector_sql},{weight})"
            parts.append(vector_sql)

        expression = " || ".join(parts)
        return f"GENERATED ALWAYS AS ({expression}) STORED"

    def describe(self, serializable: bool) -> dict:
        """
        Extend the base field description with the TSVECTOR-specific options.

        :param serializable: When true, tuples are converted to lists.
        :return: The description dict with ``source_fields``, ``config``,
            ``weights`` and ``stored`` added.
        """
        desc = super().describe(serializable)
        desc["source_fields"] = list(self.source_fields) if serializable else self.source_fields
        desc["config"] = self.config
        if self.weights is None:
            desc["weights"] = None
        else:
            desc["weights"] = list(self.weights) if serializable else self.weights
        desc["stored"] = self.stored
        return desc

    class _db_postgres:
        """Override holder exposing the generated-column SQL of the wrapped field."""

        def __init__(self, field: TSVectorField) -> None:
            self.field = field

        @property
        def GENERATED_SQL(self) -> str | None:
            # Computed on access by delegating to the wrapped field.
            return self.field._get_generated_sql()
class ArrayField(Field, list):  # type: ignore
    """
    Array column whose SQL type is ``<ELEMENT_TYPE>[]``.

    :param element_type: SQL name of the element type; stored upper-cased.
    :param kwargs: Forwarded unchanged to ``Field``.
    """

    def __init__(self, element_type: str = "int", **kwargs: Any):
        super().__init__(**kwargs)
        normalized_type = element_type.upper()
        self.element_type = normalized_type

    @property
    def SQL_TYPE(self) -> str:  # type: ignore
        """SQL type of the column: the element type followed by ``[]``."""
        return "{}[]".format(self.element_type)