# Source code for tortoise.connection
from __future__ import annotations
import asyncio
import contextvars
import importlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
from tortoise.backends.base.config_generator import expand_db_url
from tortoise.exceptions import ConfigurationError
if TYPE_CHECKING:
from tortoise.backends.base.client import BaseDBAsyncClient
# Connections config: maps each connection alias to either a DB URL string or a
# config dict carrying ``engine`` and ``credentials`` keys
# (see ConnectionHandler._create_connection).
DBConfigType = dict[str, Any]
@dataclass(slots=True)
class ConnectionToken:
"""
Token for resetting connection storage modifications.
Used by transactions to temporarily replace a connection with a transaction client,
then restore the original connection when the transaction completes.
"""
_handler: ConnectionHandler
_alias: str
_old_value: BaseDBAsyncClient | None
_cv_token: contextvars.Token | None = field(default=None)
_used: bool = field(default=False)
class ConnectionHandler:
    """
    Connection management for a single TortoiseContext.

    Each TortoiseContext owns its own ConnectionHandler instance with isolated storage.
    Connection objects are created lazily from :attr:`db_config` by :meth:`get` and
    cached in a ``ContextVar``-backed dict. As a result, :meth:`set` / :meth:`reset`
    (used by transactions) only affect the current context.
    """

    def __init__(self) -> None:
        """Initialize connection handler with empty storage."""
        self._db_config: DBConfigType | None = None
        self._create_db: bool = False
        # Use ContextVar for task isolation within this handler instance.
        # This ensures transactions (which use .set()) are isolated to the task.
        # NOTE: the ``default={}`` dict is created once per handler. ``.get()`` returns
        # that same dict object in every context that has not called ``.set()``, so
        # connections cached into it by ``get()`` are shared by all such contexts.
        self._storage_var: contextvars.ContextVar[dict[str, BaseDBAsyncClient]] = (
            contextvars.ContextVar(f"storage_{id(self)}", default={})
        )

    @property
    def _storage(self) -> dict[str, BaseDBAsyncClient]:
        """
        Internal storage for connections.

        We use a property to provide a dict-like interface while being backed by a ContextVar.
        """
        return self._get_storage()

    @_storage.setter
    def _storage(self, value: dict[str, BaseDBAsyncClient]) -> None:
        """Allow direct assignment to storage for legacy compatibility (and tests)."""
        self._storage_var.set(value)

    def _get_storage(self) -> dict[str, BaseDBAsyncClient]:
        """Get the connection storage dict for the current task context."""
        return self._storage_var.get()

    def _set_storage(self, new_storage: dict[str, BaseDBAsyncClient]) -> None:
        """Set the connection storage dict. Used for testing purposes."""
        self._storage = new_storage

    def _copy_storage(self) -> dict[str, BaseDBAsyncClient]:
        """Return a shallow copy of the storage."""
        return dict(self._get_storage())

    def _clear_storage(self) -> None:
        """Clear all connections from storage in the current context."""
        self._storage_var.set({})

    async def _init(self, db_config: DBConfigType, create_db: bool) -> None:
        """
        Merge ``db_config`` into the handler's config and initialise every connection.

        :param db_config: Mapping of connection alias to DB URL string or config dict.
        :param create_db: If ``True``, ``db_create()`` is awaited on every connection.
        """
        self._init_config(db_config, create_db)
        await self._init_connections()

    def _init_config(self, db_config: DBConfigType, create_db: bool = False) -> None:
        """
        Merge ``db_config`` into the handler's config without creating connections.

        The first call stores a shallow copy of ``db_config``. Later calls update that
        copy, so the dict the caller passed in is never mutated by a subsequent merge.

        :param db_config: Mapping of connection alias to DB URL string or config dict.
        :param create_db: Remembered for :meth:`_init_connections`.
        """
        if self._db_config is None:
            # Copy so a later ``update()`` never writes into the caller's dict.
            self._db_config = dict(db_config)
        else:
            self._db_config.update(db_config)
        self._create_db = create_db

    @property
    def db_config(self) -> DBConfigType:
        """
        Return the DB config.

        This is the config passed to the
        :meth:`Tortoise.init<tortoise.Tortoise.init>` method during initialization
        (merged with any config supplied by later initialisation calls).

        :raises ConfigurationError:
            If this property is accessed before calling the
            :meth:`Tortoise.init<tortoise.Tortoise.init>` method.
        """
        if self._db_config is None:
            raise ConfigurationError(
                "DB configuration not initialised. Make sure to call "
                "Tortoise.init with a valid configuration before attempting "
                "to create connections."
            )
        return self._db_config

    def _discover_client_class(self, db_info: dict) -> type[BaseDBAsyncClient]:
        """
        Import the backend module named by ``db_info["engine"]`` and return its client class.

        The module may expose either ``get_client_class(db_info)`` or a
        ``client_class`` attribute.

        :raises ConfigurationError: If the backend module exposes neither.
        """
        # Let import errors bubble up for transparency
        engine_str = db_info.get("engine", "")
        engine_module = importlib.import_module(engine_str)
        try:
            if hasattr(engine_module, "get_client_class"):
                client_class = engine_module.get_client_class(db_info)
            else:
                client_class = engine_module.client_class
        except AttributeError as err:
            raise ConfigurationError(
                f'Backend for engine "{engine_str}" does not implement db client'
            ) from err
        return client_class

    def _get_db_info(self, conn_alias: str) -> str | dict:
        """
        Return the raw config entry (DB URL string or config dict) for ``conn_alias``.

        :raises ConfigurationError: If ``conn_alias`` is not in :attr:`db_config`.
        """
        try:
            return self.db_config[conn_alias]
        except KeyError:
            raise ConfigurationError(
                f"Unable to get db settings for alias '{conn_alias}'. Please "
                f"check if the config dict contains this alias and try again"
            ) from None

    async def _init_connections(self) -> None:
        """Create (via :meth:`get`) every configured connection, creating DBs if requested."""
        for alias in self.db_config:
            connection: BaseDBAsyncClient = self.get(alias)
            if self._create_db:
                await connection.db_create()

    def _create_connection(self, conn_alias: str) -> BaseDBAsyncClient:
        """
        Build a new client instance for ``conn_alias`` from its config entry.

        URL strings are expanded with ``expand_db_url`` first; the entry's
        ``credentials`` plus ``connection_name=conn_alias`` become constructor kwargs.
        """
        db_info = self._get_db_info(conn_alias)
        if isinstance(db_info, str):
            db_info = expand_db_url(db_info)
        client_class = self._discover_client_class(db_info)
        db_params = db_info["credentials"].copy()
        db_params["connection_name"] = conn_alias
        connection: BaseDBAsyncClient = client_class(**db_params)
        return connection

    def get(self, conn_alias: str) -> BaseDBAsyncClient:
        """
        Return the connection object for the given alias, creating it if needed.

        Used for accessing the low-level connection object
        (:class:`BaseDBAsyncClient<tortoise.backends.base.client.BaseDBAsyncClient>`) for the
        given alias. A newly created connection is cached in the current context's storage.

        :param conn_alias: The alias for which the connection has to be fetched
        :raises ConfigurationError: If the connection alias does not exist.
        """
        storage = self._get_storage()
        try:
            return storage[conn_alias]
        except KeyError:
            connection: BaseDBAsyncClient = self._create_connection(conn_alias)
            storage[conn_alias] = connection
            return connection

    def set(self, conn_alias: str, conn_obj: BaseDBAsyncClient) -> ConnectionToken:
        """
        Sets the given alias to the provided connection object for the current task.

        A copy of the current storage is installed in the ContextVar, so the
        previous mapping is left untouched for other contexts.

        :param conn_alias: The alias to set the connection for.
        :param conn_obj: The connection object that needs to be set for this alias.
        :returns: A token that can be used to restore the previous context via reset().

        .. note::
            This method is primarily used by transactions to temporarily replace a connection
            with a transaction client. Call reset() with the returned token to restore the
            original connection when the transaction completes.
        """
        old_value = self._get_storage().get(conn_alias)
        storage_copy = self._copy_storage()
        storage_copy[conn_alias] = conn_obj
        cv_token = self._storage_var.set(storage_copy)
        return ConnectionToken(
            _handler=self, _alias=conn_alias, _old_value=old_value, _cv_token=cv_token
        )

    def discard(self, conn_alias: str) -> BaseDBAsyncClient | None:
        """
        Discards the given alias from the storage in the `current context`.

        :param conn_alias: The alias for which the connection object should be discarded.
        :returns: The removed connection object, or ``None`` if the alias was not stored.

        .. important::
            Make sure to have called ``conn.close()`` for the provided alias before calling
            this method else there would be a connection leak (dangling connection).
        """
        return self._get_storage().pop(conn_alias, None)

    def reset(self, token: ConnectionToken | None) -> None:
        """
        Undo a previous :meth:`set` call.

        When the token carries a ContextVar token, the whole storage mapping of the
        current context is rolled back to what it was before set(). Otherwise (e.g.
        hand-built tokens in tests), only ``token._alias`` is restored to
        ``token._old_value``, or removed when that is ``None``.

        :param token:
            The token returned by the set() method. Can be None (no-op).
        :raises ValueError:
            If the token has already been used (``ContextVar.reset`` also raises
            ``ValueError`` for a token created in a different context).
        """
        if token is None:
            return
        if token._used:
            raise ValueError("Token has already been used")
        token._used = True
        if isinstance(token._cv_token, contextvars.Token):
            self._storage_var.reset(token._cv_token)
            return
        # Fallback when no ContextVar token (e.g., mock tokens in tests)
        storage = self._copy_storage()
        if token._old_value is None:
            storage.pop(token._alias, None)
        else:
            storage[token._alias] = token._old_value
        self._storage = storage

    def all(self) -> list[BaseDBAsyncClient]:
        """Returns a list of connection objects from the storage in the `current context`."""
        # The reason this method iterates over db_config and not over `storage` directly is
        # because: assume that someone calls `discard` with a certain alias, and calls this
        # method subsequently. The alias which just got discarded from the storage would not
        # appear in the returned list though it exists as part of the `db_config`.
        # Consequence: aliases missing from storage get a fresh client created via get().
        return [self.get(alias) for alias in self.db_config]

    async def close_all(self, discard: bool = True) -> None:
        """
        Closes all connections in the storage in the `current context`.

        All closed connections will be removed from the storage by default.

        :param discard:
            If ``False``, all connection objects are closed but `retained` in the storage.
        """
        # Handle case where connections were never initialized (e.g., init failed)
        if self._db_config is None:
            return
        tasks = [conn.close() for conn in self.all()]
        await asyncio.gather(*tasks)
        if discard:
            for alias in self.db_config:
                self.discard(alias)
class _ConnectionsProxy:
"""
Simple delegator that forwards all operations to the current context's ConnectionHandler.
This provides backward compatibility for code using the `connections` module-level singleton.
All operations require an active TortoiseContext - if no context is active, a clear error is raised.
.. deprecated::
Direct use of `connections` is deprecated. Use `get_connection()` or `get_connections()` instead,
or access connections through the context: `ctx.connections`.
"""
def _get_handler(self) -> ConnectionHandler:
"""Get the ConnectionHandler from the current context."""
from tortoise.context import require_context
return require_context().connections
def __getattr__(self, name: str):
"""Delegate attribute access to the current context's ConnectionHandler."""
return getattr(self._get_handler(), name)
# Properties must be explicit since __getattr__ doesn't intercept descriptor access
@property
def db_config(self) -> DBConfigType:
"""Return the DB config."""
return self._get_handler().db_config
connections = _ConnectionsProxy()
def get_connection(alias: str) -> BaseDBAsyncClient:
    """
    Look up a database connection by alias in the active TortoiseContext.

    Convenience wrapper; code that already holds a context should prefer
    ``ctx.connections.get(alias)``.

    :param alias: The connection alias (e.g., "default")
    :raises ConfigurationError: If no context is active or connection not found
    """
    # Imported lazily, presumably to avoid a circular import with tortoise.context.
    from tortoise.context import require_context

    handler = require_context().connections
    return handler.get(alias)
def get_connections() -> ConnectionHandler:
    """
    Return the ConnectionHandler owned by the active TortoiseContext.

    Convenience wrapper; code that already holds a context should prefer
    ``ctx.connections``.

    :raises ConfigurationError: If no context is active
    """
    # Imported lazily, presumably to avoid a circular import with tortoise.context.
    from tortoise.context import require_context

    active_context = require_context()
    return active_context.connections