Source code for tortoise.contrib.test

"""
Modern testing utilities for Tortoise ORM.

Use tortoise_test_context() with pytest fixtures:

    @pytest_asyncio.fixture
    async def db():
        async with tortoise_test_context(["myapp.models"]) as ctx:
            yield ctx

    @pytest.mark.asyncio
    async def test_example(db):
        user = await User.create(name="Test")
        assert user.id is not None

For capability-based test skipping:

    @requireCapability(dialect="sqlite")
    @pytest.mark.asyncio
    async def test_sqlite_only(db):
        # This test only runs on SQLite
        ...
"""

from __future__ import annotations

import inspect
import typing
from collections.abc import Callable, Coroutine
from functools import partial, wraps
from typing import TYPE_CHECKING, ParamSpec, TypeVar, cast
from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless

from tortoise import Tortoise
from tortoise.connection import get_connection
from tortoise.context import TortoiseContext, tortoise_test_context

if TYPE_CHECKING:
    from tortoise.models import Model

T = TypeVar("T")
P = ParamSpec("P")
AsyncFunc = Callable[P, Coroutine[None, None, T]]
AsyncFuncDeco = Callable[..., AsyncFunc]
ModulesConfigType = str | list[str]
MEMORY_SQLITE = "sqlite://:memory:"

__all__ = (
    "MEMORY_SQLITE",
    "TortoiseContext",
    "tortoise_test_context",
    "requireCapability",
    "truncate_all_models",
    "init_memory_sqlite",
    "SkipTest",
    "expectedFailure",
    "skip",
    "skipIf",
    "skipUnless",
)

expectedFailure.__doc__ = """
Mark test as expecting failure.

On success it will be marked as unexpected success.
"""


[docs] async def truncate_all_models() -> None: """ Truncate all models in the current context. This is a utility function for test cleanup that deletes all rows from all registered model tables. On PostgreSQL, uses ``TRUNCATE ... CASCADE`` for a single fast statement. On other databases, deletes in topological (FK dependency) order so that child rows are removed before parent rows they reference. Raises: ValueError: If Tortoise.apps is not loaded. """ if not Tortoise.apps: raise ValueError("apps are not loaded") models = list(Tortoise.apps.get_models_iterable()) if not models: return db = models[0]._meta.db dialect = db.capabilities.dialect if dialect == "postgres": # PostgreSQL supports TRUNCATE with CASCADE — single statement, fast tables = ", ".join(f'"{m._meta.db_table}"' for m in models) await db.execute_script(f"TRUNCATE {tables} CASCADE") else: # For other dialects, topologically sort by FK dependencies (children first) sorted_models = _topological_sort_models(models) # Disable FK checks to handle self-referential and circular FK constraints if dialect == "mysql": await db.execute_script("SET FOREIGN_KEY_CHECKS = 0") elif dialect == "sqlite": await db.execute_script("PRAGMA foreign_keys = OFF") try: for model in sorted_models: quote_char = model._meta.db.query_class.SQL_CONTEXT.quote_char await model._meta.db.execute_script( f"DELETE FROM {quote_char}{model._meta.db_table}{quote_char}" # nosec ) finally: if dialect == "mysql": await db.execute_script("SET FOREIGN_KEY_CHECKS = 1") elif dialect == "sqlite": await db.execute_script("PRAGMA foreign_keys = ON")
def _topological_sort_models(models: list[type[Model]]) -> list[type[Model]]: """Sort models so children come before parents (safe delete order). Uses Kahn's algorithm on FK dependencies. Models that depend on others via ForeignKey are placed *before* the models they reference, ensuring child rows are deleted before parent rows. """ from tortoise.fields.relational import ForeignKeyFieldInstance model_set = set(models) # Build adjacency for delete order: parent -> children that must be deleted first # If Event has FK to Tournament, then Tournament depends on Event being deleted first deps: dict[type[Model], set[type[Model]]] = {m: set() for m in models} for model in models: for field in model._meta.fields_map.values(): if isinstance(field, ForeignKeyFieldInstance): related = field.related_model if related in model_set and related is not model: deps[related].add(model) # Kahn's algorithm — emit models whose deps are already emitted sorted_models: list[type[Model]] = [] no_deps = [m for m in models if not deps[m]] while no_deps: m = no_deps.pop() sorted_models.append(m) for other in models: deps[other].discard(m) if not deps[other] and other not in sorted_models and other not in no_deps: no_deps.append(other) # Append any remaining (circular deps — fallback) for m in models: if m not in sorted_models: sorted_models.append(m) return sorted_models _FT = TypeVar("_FT", bound=Callable[..., typing.Any])
[docs] def requireCapability( connection_name: str = "models", **conditions: typing.Any ) -> Callable[[_FT], _FT]: """ Skip a test if the required capabilities are not matched. .. note:: The database must be initialized *before* the decorated test runs. Usage: .. code-block:: python3 @requireCapability(dialect='sqlite') @pytest.mark.asyncio async def test_run_sqlite_only(db): ... Or to conditionally skip a class: .. code-block:: python3 @requireCapability(dialect='sqlite') class TestSqlite: @pytest.mark.asyncio async def test_something(self, db): ... :param connection_name: name of the connection to retrieve capabilities from. :param conditions: capability tests which must all pass for the test to run. """ def decorator(test_item: _FT) -> _FT: if not isinstance(test_item, type): def check_capabilities() -> None: db = get_connection(connection_name) for key, val in conditions.items(): if getattr(db.capabilities, key) != val: raise SkipTest(f"Capability {key} != {val}") if inspect.iscoroutinefunction(test_item): @wraps(test_item) async def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: check_capabilities() return await test_item(*args, **kwargs) else: @wraps(test_item) def skip_wrapper(*args: typing.Any, **kwargs: typing.Any) -> typing.Any: check_capabilities() return test_item(*args, **kwargs) return cast(_FT, skip_wrapper) # Assume a class is decorated funcs = { var: f for var in dir(test_item) if var.startswith("test_") and callable(f := getattr(test_item, var)) } for name, func in funcs.items(): setattr( test_item, name, requireCapability(connection_name=connection_name, **conditions)(func), ) return test_item return decorator
@typing.overload def init_memory_sqlite(models: ModulesConfigType | None = None) -> AsyncFuncDeco: ... @typing.overload def init_memory_sqlite(models: AsyncFunc) -> AsyncFunc: ... def init_memory_sqlite( models: ModulesConfigType | AsyncFunc | None = None, ) -> AsyncFunc | AsyncFuncDeco: """ Decorator for initializing Tortoise with an in-memory SQLite database. This is useful for simple scripts and examples that need a quick database setup. :param models: List of modules to load models from. Defaults to ["__main__"]. Usage: .. code-block:: python3 from tortoise import fields, models, run_async from tortoise.contrib.test import init_memory_sqlite class MyModel(models.Model): id = fields.IntField(primary_key=True) name = fields.TextField() @init_memory_sqlite async def run(): obj = await MyModel.create(name='') assert obj.id == 1 if __name__ == '__main__': run_async(run()) Custom models example: .. code-block:: python3 @init_memory_sqlite(models=['app.models', 'aerich.models']) async def run(): ... """ def wrapper(func: AsyncFunc, ms: list[str]): @wraps(func) async def runner(*args, **kwargs) -> T: await Tortoise.init(db_url=MEMORY_SQLITE, modules={"models": ms}) await Tortoise.generate_schemas() return await func(*args, **kwargs) return runner default_models = ["__main__"] if inspect.iscoroutinefunction(models): return wrapper(models, default_models) if models is None: models = default_models elif isinstance(models, str): models = [models] else: models = cast(list, models) return partial(wrapper, ms=models)