# Source code for tortoise.backends.base.client

from __future__ import annotations

import abc
import asyncio
from typing import (
    Any,
    Generic,
    List,
    Optional,
    Sequence,
    Tuple,
    Type,
    TypeVar,
    Union,
    cast,
)

from pypika import Query

from tortoise.backends.base.executor import BaseExecutor
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.connection import connections
from tortoise.exceptions import TransactionManagementError
from tortoise.log import db_client_logger

# Generic stand-in for the driver-level connection object a client holds,
# e.g. an ``asyncpg.Connection()`` instance.
T_conn = TypeVar("T_conn")


class Capabilities:
    """
    Describes the feature-set a DB client supports, and records common
    workarounds for deficiencies of the underlying database.

    Defaults follow this convention:

    * Deficiencies: assume it is working right.
    * Features: assume it doesn't have it.

    Instances are read-only once constructed: assigning any attribute after
    ``__init__`` has returned raises :class:`AttributeError`.

    :param dialect: Dialect name of the DB Client driver.
    :param daemon: Is the DB an external Daemon we connect to?
    :param requires_limit: Indicates that this DB requires a ``LIMIT`` statement for an
        ``OFFSET`` statement to work.
    :param inline_comment: Indicates that comments should be rendered in line with the
        DDL statement, and not as a separate statement.
    :param supports_transactions: Indicates that this DB supports transactions.
    :param support_for_update: Indicates that this DB supports SELECT ... FOR UPDATE SQL
        statement.
    :param support_index_hint: Support force index or use index.
    :param support_update_limit_order_by: Support update/delete with limit and order by.
    """

    def __init__(
        self,
        dialect: str,
        *,
        # Is the connection a Daemon?
        daemon: bool = True,
        # Deficiencies to work around:
        requires_limit: bool = False,
        inline_comment: bool = False,
        supports_transactions: bool = True,
        support_for_update: bool = True,
        # Support force index or use index?
        support_index_hint: bool = False,
        # support update/delete with limit and order by
        support_update_limit_order_by: bool = True,
    ) -> None:
        # Unlock the instance. The flag is written through the parent
        # implementation because our own __setattr__ would reject it.
        super().__setattr__("_mutable", True)
        settings = (
            ("dialect", dialect),
            ("daemon", daemon),
            ("requires_limit", requires_limit),
            ("inline_comment", inline_comment),
            ("supports_transactions", supports_transactions),
            ("support_for_update", support_for_update),
            ("support_index_hint", support_index_hint),
            ("support_update_limit_order_by", support_update_limit_order_by),
        )
        for name, value in settings:
            setattr(self, name, value)
        # Freeze: every later assignment is refused by __setattr__.
        super().__setattr__("_mutable", False)

    def __setattr__(self, attr: str, value: Any) -> None:
        # ``_mutable`` is only truthy while __init__ is running. Before it is
        # set at all, the getattr default makes the instance count as frozen.
        if getattr(self, "_mutable", False):
            super().__setattr__(attr, value)
        else:
            raise AttributeError(attr)

    def __str__(self) -> str:
        return str(vars(self))
class BaseDBAsyncClient:
    """
    Base class for containing a DB connection.

    Parameters get passed as kwargs, and are mostly driver specific.

    .. attribute:: query_class
        :annotation: Type[pypika.Query]

        The PyPika Query dialect (low level dialect)

    .. attribute:: executor_class
        :annotation: Type[BaseExecutor]

        The executor dialect class (high level dialect)

    .. attribute:: schema_generator
        :annotation: Type[BaseSchemaGenerator]

        The DDL schema generator

    .. attribute:: capabilities
        :annotation: Capabilities

        Contains the connection capabilities
    """

    query_class: Type[Query] = Query
    executor_class: Type[BaseExecutor] = BaseExecutor
    schema_generator: Type[BaseSchemaGenerator] = BaseSchemaGenerator
    # Class-level default shared by all instances. Sharing is safe because
    # Capabilities objects refuse attribute assignment after construction.
    capabilities: Capabilities = Capabilities("")

    def __init__(self, connection_name: str, fetch_inserted: bool = True, **kwargs: Any) -> None:
        """
        :param connection_name: Name identifying this connection.
        :param fetch_inserted: Stored as ``self.fetch_inserted`` for use by subclasses.
        :param kwargs: Driver-specific options; not used by this base class.
        """
        self.log = db_client_logger
        self.connection_name = connection_name
        self.fetch_inserted = fetch_inserted

    async def create_connection(self, with_db: bool) -> None:
        """
        Establish a DB connection.

        :param with_db: If True, then select the DB to use, else use default.
            Use case for this is to create/drop a database.
        """
        raise NotImplementedError()  # pragma: nocoverage

    async def close(self) -> None:
        """
        Closes the DB connection.
        """
        raise NotImplementedError()  # pragma: nocoverage

    async def db_create(self) -> None:
        """
        Creates the database in the server. Typically only called by the test runner.

        Need to have called ``create_connection()`` with parameter ``with_db=False`` set to
        use the default connection instead of the configured one, else you would get errors
        indicating the database doesn't exist.
        """
        raise NotImplementedError()  # pragma: nocoverage

    async def db_delete(self) -> None:
        """
        Deletes the database from the server. Typically only called by the test runner.

        Need to have called ``create_connection()`` with parameter ``with_db=False`` set to
        use the default connection instead of the configured one, else you would get errors
        indicating the database is in use.
        """
        raise NotImplementedError()  # pragma: nocoverage

    def acquire_connection(self) -> Union["ConnectionWrapper", "PoolConnectionWrapper"]:
        """
        Acquires a connection from the pool.
        Will return the current context connection if already in a transaction.
        """
        raise NotImplementedError()  # pragma: nocoverage

    def _in_transaction(self) -> "TransactionContext":
        """
        Returns a context manager that runs its body inside a transaction.
        Must be implemented by subclasses.
        """
        raise NotImplementedError()  # pragma: nocoverage

    async def execute_insert(self, query: str, values: list) -> Any:
        """
        Executes a RAW SQL insert statement, with provided parameters.

        :param query: The SQL string, pre-parametrized for the target DB dialect.
        :param values: A sequence of positional DB parameters.
        :return: The primary key if it is generated by the DB.
            (Currently only integer autonumber PK's)
        """
        raise NotImplementedError()  # pragma: nocoverage

    async def execute_query(
        self, query: str, values: Optional[list] = None
    ) -> Tuple[int, Sequence[dict]]:
        """
        Executes a RAW SQL query statement, and returns the resultset.

        :param query: The SQL string, pre-parametrized for the target DB dialect.
        :param values: A sequence of positional DB parameters.
        :return: A tuple of: (The number of rows affected, The resultset)
        """
        raise NotImplementedError()  # pragma: nocoverage

    async def execute_script(self, query: str) -> None:
        """
        Executes a RAW SQL script with multiple statements, and returns nothing.

        :param query: The SQL string, which will be passed on verbatim.
            Semicolons are supported here.
        """
        raise NotImplementedError()  # pragma: nocoverage

    async def execute_many(self, query: str, values: List[list]) -> None:
        """
        Executes a RAW bulk insert statement, like execute_insert, but returns no data.

        :param query: The SQL string, pre-parametrized for the target DB dialect.
        :param values: A list of positional-parameter lists, one entry per set of
            parameters to execute ``query`` with.
        """
        raise NotImplementedError()  # pragma: nocoverage

    async def execute_query_dict(self, query: str, values: Optional[list] = None) -> List[dict]:
        """
        Executes a RAW SQL query statement, and returns the resultset as a list of dicts.

        :param query: The SQL string, pre-parametrized for the target DB dialect.
        :param values: A sequence of positional DB parameters.
        :return: The resultset, one dict per row.
        """
        raise NotImplementedError()  # pragma: nocoverage
class ConnectionWrapper(Generic[T_conn]):
    """Wraps the connections with a lock to facilitate safe concurrent access when using
    asyncio.gather, TaskGroup, or similar."""

    __slots__ = ("connection", "lock", "client")

    def __init__(self, lock: asyncio.Lock, client: Any) -> None:
        self.lock: asyncio.Lock = lock
        self.client = client
        # Snapshot of the client's current connection. It may be falsy if the
        # client has not connected yet; ensure_connection() handles that case.
        self.connection: T_conn = client._connection

    async def ensure_connection(self) -> None:
        """Create the client's connection on demand if none is held yet."""
        if not self.connection:
            await self.client.create_connection(with_db=True)
            self.connection = self.client._connection

    async def __aenter__(self) -> T_conn:
        """Acquire the lock, make sure a connection exists, and return it."""
        await self.lock.acquire()
        try:
            await self.ensure_connection()
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises. Release the lock
            # here, otherwise every later user would wait on it forever.
            self.lock.release()
            raise
        return self.connection

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self.lock.release()


class TransactionContext(Generic[T_conn]):
    """A context manager interface for transactions. It is returned from in_transaction
    and _in_transaction."""

    connection: T_conn

    @abc.abstractmethod
    async def __aenter__(self) -> T_conn: ...

    @abc.abstractmethod
    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...


class TransactionContextPooled(TransactionContext):
    "A version of TransactionContext that uses a pool to acquire connections."

    __slots__ = ("conn_wrapper", "connection", "connection_name", "token")

    def __init__(self, connection: Any) -> None:
        # ``connection`` is the transaction wrapper (it provides begin/commit/
        # rollback); the pool lives on its ``_parent`` client.
        self.connection = connection
        self.connection_name = connection.connection_name

    async def ensure_connection(self) -> None:
        """Create the parent client's pool on demand if it does not exist yet."""
        if not self.connection._parent._pool:
            await self.connection._parent.create_connection(with_db=True)

    async def __aenter__(self) -> T_conn:
        """Acquire a pooled connection, begin a transaction and publish the wrapper.

        If acquiring the connection or ``begin()`` fails, any acquired connection
        is returned to the pool and the context variable is restored before the
        error propagates.
        """
        await self.ensure_connection()
        # Set the context variable so the current task is always seeing a
        # TransactionWrapper connection.
        self.token = connections.set(self.connection_name, self.connection)
        acquired = False
        try:
            self.connection._connection = await self.connection._parent._pool.acquire()
            acquired = True
            await self.connection.begin()
        except BaseException:
            # __aexit__ is not invoked when __aenter__ raises, so undo the partial
            # setup here: give the connection back and restore the context var.
            try:
                if acquired and self.connection._parent._pool:
                    await self.connection._parent._pool.release(self.connection._connection)
            finally:
                connections.reset(self.token)
            raise
        return self.connection

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Commit on a clean exit and roll back on error, unless already finalized.

        The pooled connection is always released and the context variable restored.
        """
        try:
            if not self.connection._finalized:
                if exc_type:
                    # Can't rollback a transaction that already failed.
                    if exc_type is not TransactionManagementError:
                        await self.connection.rollback()
                else:
                    await self.connection.commit()
        finally:
            if self.connection._parent._pool:
                await self.connection._parent._pool.release(self.connection._connection)
            connections.reset(self.token)


class NestedTransactionContext(TransactionContext):
    """Transaction context used inside an already-open transaction; it works via savepoints."""

    def __init__(self, connection: Any) -> None:
        self.connection = connection
        self.connection_name = connection.connection_name

    async def __aenter__(self) -> T_conn:
        await self.connection.savepoint()
        return self.connection

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if not self.connection._finalized:
            if exc_type:
                # Can't rollback a transaction that already failed.
                if exc_type is not TransactionManagementError:
                    await self.connection.savepoint_rollback()
            else:
                await self.connection.release_savepoint()


class PoolConnectionWrapper(Generic[T_conn]):
    """Class to manage acquiring from and releasing connections to a pool."""

    def __init__(self, client: Any) -> None:
        self.pool = client._pool
        self.client = client
        self.connection: Optional[T_conn] = None

    async def ensure_connection(self) -> None:
        """Create the client's pool on demand if it does not exist yet."""
        if not self.pool:
            await self.client.create_connection(with_db=True)
            self.pool = self.client._pool

    async def __aenter__(self) -> T_conn:
        await self.ensure_connection()
        # get first available connection. If none available, wait until one is released
        self.connection = await self.pool.acquire()
        return cast(T_conn, self.connection)

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # release the connection back to the pool
        await self.pool.release(self.connection)


class BaseTransactionWrapper:
    """Interface called by the transaction context managers in this module.

    begin/commit/rollback drive an outer transaction. savepoint/savepoint_rollback/
    release_savepoint drive nested ones. Note: this class does not use
    ``abc.ABCMeta``, so ``abstractmethod`` is not enforced at instantiation.
    """

    @abc.abstractmethod
    async def begin(self) -> None: ...

    @abc.abstractmethod
    async def savepoint(self) -> None: ...

    @abc.abstractmethod
    async def rollback(self) -> None: ...

    @abc.abstractmethod
    async def savepoint_rollback(self) -> None: ...

    @abc.abstractmethod
    async def commit(self) -> None: ...

    @abc.abstractmethod
    async def release_savepoint(self) -> None: ...