from __future__ import annotations

import abc
import asyncio
from collections.abc import Sequence
from typing import Any, Generic, TypeVar, cast

from pypika_tortoise import Query

from tortoise.backends.base.executor import BaseExecutor
from tortoise.backends.base.schema_generator import BaseSchemaGenerator
from tortoise.connection import connections
from tortoise.exceptions import TransactionManagementError
from tortoise.log import db_client_logger

# Type of the driver-level connection object held by a client,
# e.g. an ``asyncpg.Connection()`` instance.
T_conn = TypeVar("T_conn")
class Capabilities:
    """
    Describe the feature set of a DB client, including flags that mark common
    workarounds for deficiencies of the underlying database.

    Default values follow one convention:

    * Deficiencies: assume it is working right.
    * Features: assume it doesn't have it.

    :param dialect: Dialect name of the DB client driver.
    :param daemon: Whether the DB is an external daemon that we connect to.
    :param requires_limit: The DB needs a ``LIMIT`` clause for an ``OFFSET``
        clause to work.
    :param inline_comment: Comments are rendered inline with the DDL statement
        instead of as a separate statement.
    :param supports_transactions: The DB supports transactions.
    :param support_for_update: The DB supports the ``SELECT ... FOR UPDATE``
        SQL statement.
    :param support_index_hint: The DB supports FORCE INDEX / USE INDEX hints.
    :param support_update_limit_order_by: The DB supports UPDATE/DELETE with
        ``LIMIT`` and ``ORDER BY``.
    :param support_for_posix_regex_queries: The DB supports POSIX regex queries.

    Instances are read-only once constructed: assigning any attribute after
    ``__init__`` has finished raises :class:`AttributeError`.
    """

    def __init__(
        self,
        dialect: str,
        *,
        # Is the connection a Daemon?
        daemon: bool = True,
        # Deficiencies to work around:
        requires_limit: bool = False,
        inline_comment: bool = False,
        supports_transactions: bool = True,
        support_for_update: bool = True,
        # Support force index or use index?
        support_index_hint: bool = False,
        # support update/delete with limit and order by
        support_update_limit_order_by: bool = True,
        support_for_posix_regex_queries: bool = False,
    ) -> None:
        # Unlock attribute assignment only while the constructor runs.
        super().__setattr__("_mutable", True)
        # The key order is significant: __str__ renders __dict__ in insertion order.
        settings = {
            "dialect": dialect,
            "daemon": daemon,
            "requires_limit": requires_limit,
            "inline_comment": inline_comment,
            "supports_transactions": supports_transactions,
            "support_for_update": support_for_update,
            "support_index_hint": support_index_hint,
            "support_update_limit_order_by": support_update_limit_order_by,
            "support_for_posix_regex_queries": support_for_posix_regex_queries,
        }
        for name, value in settings.items():
            setattr(self, name, value)
        super().__setattr__("_mutable", False)

    def __setattr__(self, attr: str, value: Any) -> None:
        # ``_mutable`` may be missing (e.g. an instance made via __new__),
        # which counts as frozen.
        if getattr(self, "_mutable", False):
            super().__setattr__(attr, value)
        else:
            raise AttributeError(attr)

    def __str__(self) -> str:
        return str(vars(self))
class BaseDBAsyncClient(abc.ABC):
    """
    Base class for containing a DB connection.

    Parameters get passed as kwargs, and are mostly driver specific.

    .. attribute:: query_class
        :annotation: type[pypika_tortoise.Query]

        The PyPika Query dialect (low level dialect)

    .. attribute:: executor_class
        :annotation: type[BaseExecutor]

        The executor dialect class (high level dialect)

    .. attribute:: schema_generator
        :annotation: type[BaseSchemaGenerator]

        The DDL schema generator

    .. attribute:: capabilities
        :annotation: Capabilities

        Contains the connection capabilities
    """

    _connection: Any
    _parent: BaseDBAsyncClient
    _pool: Any
    connection_name: str
    query_class: type[Query] = Query
    executor_class: type[BaseExecutor] = BaseExecutor
    schema_generator: type[BaseSchemaGenerator] = BaseSchemaGenerator
    capabilities: Capabilities = Capabilities("")

    def __init__(self, connection_name: str, fetch_inserted: bool = True, **kwargs: Any) -> None:
        # Driver-specific options arrive via **kwargs; the base class does not use them.
        self.connection_name = connection_name
        self.fetch_inserted = fetch_inserted
        self.log = db_client_logger

    async def create_connection(self, with_db: bool) -> None:
        """
        Establish a DB connection.

        :param with_db: If True, select the configured DB; otherwise use the default one.
            Connecting without the DB is what creating/dropping a database needs.
        """
        raise NotImplementedError  # pragma: nocoverage

    async def close(self) -> None:
        """
        Close the DB connection.
        """
        raise NotImplementedError  # pragma: nocoverage

    async def db_create(self) -> None:
        """
        Create the database on the server.

        Typically only called by the test runner.

        ``create_connection()`` must have been called with ``with_db=False`` first, so the
        default connection is used instead of the configured one; otherwise you get errors
        saying the database doesn't exist.
        """
        raise NotImplementedError  # pragma: nocoverage

    async def db_delete(self) -> None:
        """
        Delete the database from the server.

        Typically only called by the test runner.

        ``create_connection()`` must have been called with ``with_db=False`` first, so the
        default connection is used instead of the configured one; otherwise you get errors
        saying the database is in use.
        """
        raise NotImplementedError  # pragma: nocoverage

    def acquire_connection(self) -> ConnectionWrapper | PoolConnectionWrapper:
        """
        Acquire a connection from the pool.

        If already inside a transaction, the current context connection is returned instead.
        """
        raise NotImplementedError  # pragma: nocoverage

    async def execute_insert(self, query: str, values: list) -> Any:
        """
        Execute a raw SQL INSERT statement with the given parameters.

        :param query: The SQL string, pre-parametrized for the target DB dialect.
        :param values: A sequence of positional DB parameters.
        :return: The primary key if the DB generates it
            (currently only integer auto-number PKs).
        """
        raise NotImplementedError  # pragma: nocoverage

    async def execute_query(
        self, query: str, values: list | None = None
    ) -> tuple[int, Sequence[dict]]:
        """
        Execute a raw SQL query statement and return the result set.

        :param query: The SQL string, pre-parametrized for the target DB dialect.
        :param values: A sequence of positional DB parameters.
        :return: A tuple of (number of rows affected, result set).
        """
        raise NotImplementedError  # pragma: nocoverage

    async def execute_script(self, query: str) -> None:
        """
        Execute a raw SQL script containing multiple statements; returns nothing.

        :param query: The SQL string, passed on verbatim. Semicolons are supported here.
        """
        raise NotImplementedError  # pragma: nocoverage

    async def execute_many(self, query: str, values: list[list]) -> None:
        """
        Execute a raw bulk INSERT statement, like ``execute_insert`` but returning no data.

        :param query: The SQL string, pre-parametrized for the target DB dialect.
        :param values: A list of positional DB parameter lists, one list per execution
            of ``query``.
        """
        raise NotImplementedError  # pragma: nocoverage

    async def execute_query_dict(self, query: str, values: list | None = None) -> list[dict]:
        """
        Execute a raw SQL query statement and return the result set as a list of dicts.

        :param query: The SQL string, pre-parametrized for the target DB dialect.
        :param values: A sequence of positional DB parameters.
        """
        raise NotImplementedError  # pragma: nocoverage
class TransactionalDBClient(BaseDBAsyncClient, abc.ABC):
    """An interface of the DB client that supports transactions."""

    # When True, the transaction contexts below skip their own commit/rollback
    # on exit. Presumably implementations set it once the transaction has been
    # finished explicitly — confirm in the concrete clients.
    _finalized: bool = False

    @abc.abstractmethod
    async def begin(self) -> None:
        """Start the transaction (called by ``TransactionContextPooled.__aenter__``)."""

    @abc.abstractmethod
    async def savepoint(self) -> None:
        """Create a savepoint (called by ``NestedTransactionContext.__aenter__``)."""

    @abc.abstractmethod
    async def rollback(self) -> None:
        """Roll back the transaction (used on error exit of ``TransactionContextPooled``)."""

    @abc.abstractmethod
    async def savepoint_rollback(self) -> None:
        """Roll back to the savepoint (used on error exit of ``NestedTransactionContext``)."""

    @abc.abstractmethod
    async def commit(self) -> None:
        """Commit the transaction (used on clean exit of ``TransactionContextPooled``)."""

    @abc.abstractmethod
    async def release_savepoint(self) -> None:
        """Release the savepoint (used on clean exit of ``NestedTransactionContext``)."""


class ConnectionWrapper(Generic[T_conn]):
    """Wraps the connections with a lock to facilitate safe concurrent access when using
    asyncio.gather, TaskGroup, or similar."""

    __slots__ = ("connection", "_lock", "client")

    def __init__(self, lock: asyncio.Lock, client: BaseDBAsyncClient) -> None:
        self._lock: asyncio.Lock = lock
        self.client = client
        self.connection: T_conn = client._connection

    async def ensure_connection(self) -> None:
        """Open the client's connection lazily if it has not been created yet."""
        if not self.connection:
            await self.client.create_connection(with_db=True)
            self.connection = self.client._connection

    async def __aenter__(self) -> T_conn:
        """Take the lock, make sure a connection exists, and return it."""
        await self._lock.acquire()
        try:
            await self.ensure_connection()
        except BaseException:
            # ``async with`` does not call __aexit__ when __aenter__ raises, so the
            # lock must be released here or every later user would wait forever.
            self._lock.release()
            raise
        return self.connection

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        self._lock.release()


class TransactionContext(Generic[T_conn]):
    """A context manager interface for transactions.
    It is returned from in_transaction and _in_transaction."""

    client: TransactionalDBClient

    @abc.abstractmethod
    async def __aenter__(self) -> T_conn: ...

    @abc.abstractmethod
    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: ...


class TransactionContextPooled(TransactionContext):
    """A version of TransactionContext that uses a pool to acquire connections."""

    __slots__ = ("client", "connection_name", "token", "_pool_init_lock")

    def __init__(self, client: TransactionalDBClient, pool_init_lock: asyncio.Lock) -> None:
        self.client = client
        self.connection_name = client.connection_name
        self._pool_init_lock = pool_init_lock

    async def ensure_connection(self) -> None:
        """Create the parent client's pool if it does not exist yet."""
        if not self.client._parent._pool:
            # a safeguard against multiple concurrent tasks trying to initialize the pool
            async with self._pool_init_lock:
                if not self.client._parent._pool:
                    await self.client._parent.create_connection(with_db=True)

    async def __aenter__(self) -> TransactionalDBClient:
        """Bind the transaction client to the current context, acquire a pooled
        connection and begin the transaction.

        If anything fails here, the acquired connection is released and the
        context variable is restored before the exception propagates, because
        ``__aexit__`` will not run.
        """
        await self.ensure_connection()
        # Set the context variable so the current task is always seeing a
        # TransactionWrapper connection.
        self.token = connections.set(self.connection_name, self.client)
        acquired = False
        try:
            self.client._connection = await self.client._parent._pool.acquire()
            acquired = True
            await self.client.begin()
        except BaseException:
            try:
                if acquired:
                    await self.client._parent._pool.release(self.client._connection)
            finally:
                connections.reset(self.token)
            raise
        return self.client

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        """Commit (or roll back on error), then return the connection to the pool
        and restore the previous context-variable value."""
        try:
            if not self.client._finalized:
                if exc_type:
                    # Can't rollback a transaction that already failed.
                    if exc_type is not TransactionManagementError:
                        await self.client.rollback()
                else:
                    await self.client.commit()
        finally:
            try:
                if self.client._parent._pool:
                    await self.client._parent._pool.release(self.client._connection)
            finally:
                # Always undo connections.set(), even if releasing the connection failed.
                connections.reset(self.token)


class NestedTransactionContext(TransactionContext):
    """Transaction context for a nested transaction, implemented with savepoints
    on an already-open transaction client."""

    __slots__ = ("client", "connection_name")

    def __init__(self, client: TransactionalDBClient) -> None:
        self.client = client
        self.connection_name = client.connection_name

    async def __aenter__(self) -> TransactionalDBClient:
        await self.client.savepoint()
        return self.client

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        if not self.client._finalized:
            if exc_type:
                # Can't rollback a transaction that already failed.
                if exc_type is not TransactionManagementError:
                    await self.client.savepoint_rollback()
            else:
                await self.client.release_savepoint()


class PoolConnectionWrapper(Generic[T_conn]):
    """Class to manage acquiring from and releasing connections to a pool."""

    __slots__ = ("client", "connection", "_pool_init_lock")

    def __init__(self, client: BaseDBAsyncClient, pool_init_lock: asyncio.Lock) -> None:
        self.client = client
        self.connection: T_conn | None = None
        self._pool_init_lock = pool_init_lock

    async def ensure_connection(self) -> None:
        """Create the client's pool if it does not exist yet."""
        if not self.client._pool:
            # a safeguard against multiple concurrent tasks trying to initialize the pool
            async with self._pool_init_lock:
                if not self.client._pool:
                    await self.client.create_connection(with_db=True)

    async def __aenter__(self) -> T_conn:
        await self.ensure_connection()
        # get first available connection. If none available, wait until one is released
        self.connection = await self.client._pool.acquire()
        return cast(T_conn, self.connection)

    async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
        # release the connection back to the pool
        await self.client._pool.release(self.connection)