from __future__ import annotations
import asyncio
import inspect
import os as _os
import sys
import typing
import unittest
from collections.abc import Callable, Coroutine, Iterable
from functools import partial, wraps
from types import ModuleType
from typing import TYPE_CHECKING, Any, TypeVar, Union, cast
from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless
from tortoise import Model, Tortoise, connections
from tortoise.backends.base.config_generator import generate_config as _generate_config
from tortoise.exceptions import DBConnectionError, OperationalError
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
if TYPE_CHECKING:
from asyncio.events import AbstractEventLoop
# Public API re-exported by this test-helper module.
__all__ = (
    "MEMORY_SQLITE",
    "SimpleTestCase",
    "TestCase",
    "TruncationTestCase",
    "IsolatedTestCase",
    "getDBConfig",
    "requireCapability",
    "env_initializer",
    "initializer",
    "finalizer",
    "SkipTest",
    "expectedFailure",
    "skip",
    "skipIf",
    "skipUnless",
    "init_memory_sqlite",
)

# DB URL read by ``getDBConfig``; ``initializer`` overwrites it when given a ``db_url``.
_TORTOISE_TEST_DB = "sqlite://:memory:"

# pylint: disable=W0201

# NOTE: this assigns to the docstring of ``unittest.expectedFailure`` itself, so the
# new text is visible to every importer of ``unittest``, not only to this module.
expectedFailure.__doc__ = (
    "\n"
    "Mark test as expecting failure.\n"
    "On success it will be marked as unexpected success.\n"
)

# --- State captured by ``initializer`` and reused by the test-case classes ---
# Tortoise config dict built by ``getDBConfig`` in ``initializer``.
_CONFIG: dict = {}
# Copy of the connection storage taken right after the test DB was created.
_CONNECTIONS: dict = {}
# Event loop chosen by ``initializer`` and reused by ``finalizer``.
_LOOP: AbstractEventLoop = None  # type: ignore
# Model modules passed to ``initializer``; the fallback for ``IsolatedTestCase``.
_MODULES: Iterable[str | ModuleType] = []
# Copy of ``connections.db_config`` taken right after the test DB was created.
_CONN_CONFIG: dict = {}
[docs]
def getDBConfig(app_label: str, modules: Iterable[str | ModuleType]) -> dict:
    """
    DB Config factory, for use in testing.

    The config targets the module-level test DB URL as it is at call time
    (``sqlite://:memory:`` unless ``initializer`` was given another ``db_url``).

    :param app_label: Label of the app (must be distinct for multiple apps).
        It is also used as the connection label.
    :param modules: List of modules to look for models in.
    :return: The config dict produced by ``generate_config`` with ``testing=True``.
    """
    target_url = _TORTOISE_TEST_DB
    return _generate_config(
        target_url,
        app_modules={app_label: modules},
        connection_label=app_label,
        testing=True,
    )
async def _init_db(config: dict) -> None:
    """
    Recreate the test database described by *config* and build its schema.

    Existing test databases are dropped first (best effort), then Tortoise is
    initialised again with ``_create_db=True`` and all tables are generated
    with ``safe=False``.

    :param config: A Tortoise config dict, e.g. one built by ``getDBConfig``.
    """
    # Init stays outside the ``try``: it does not open DB connections eagerly,
    # so connection problems only surface in the drop below.
    await Tortoise.init(config)
    try:
        await Tortoise._drop_databases()
    except (DBConnectionError, OperationalError):  # pragma: nocoverage
        # Presumably the database does not exist yet; nothing to drop.
        pass
    await Tortoise.init(config, _create_db=True)
    await Tortoise.generate_schemas(safe=False)
def _restore_default() -> None:
    """
    Reinstate the Tortoise state captured by ``initializer``.

    Installs copies of the saved connection storage and connection config,
    rebuilds the apps from ``_CONFIG["apps"]`` and marks Tortoise initialised.
    """
    Tortoise.apps = {}
    saved_connections = _CONNECTIONS.copy()
    connections._get_storage().update(saved_connections)
    connections._db_config = _CONN_CONFIG.copy()
    app_config = _CONFIG["apps"]
    Tortoise._init_apps(app_config)
    Tortoise._inited = True
async def truncate_all_models() -> None:
    """
    Delete every row from the table of every model registered in ``Tortoise.apps``.

    Each table name is quoted with the quote character of the model's own
    connection, and a plain ``DELETE FROM`` script is executed on that connection.
    """
    # TODO: This is a naive implementation: Will fail to clear M2M and non-cascade foreign keys
    for models_by_name in Tortoise.apps.values():
        for model_cls in models_by_name.values():
            q = model_cls._meta.db.query_class.SQL_CONTEXT.quote_char
            quoted_table = f"{q}{model_cls._meta.db_table}{q}"
            await model_cls._meta.db.execute_script(f"DELETE FROM {quoted_table}")  # nosec
[docs]
def initializer(
    modules: Iterable[str | ModuleType],
    db_url: str | None = None,
    app_label: str = "models",
    loop: AbstractEventLoop | None = None,
) -> None:
    """
    Sets up the DB for testing. Must be called as part of test environment setup.

    The test DB is created and its schema generated. The resulting connection state
    is then saved for the test-case classes, and Tortoise is left uninitialised.

    :param modules: List of modules to look for models in.
    :param db_url: The db_url to test against. When omitted, the currently
        configured URL (initially ``sqlite://:memory:``) is kept.
    :param app_label: The name of the APP to initialise the modules in, defaults to "models"
    :param loop: Optional event loop. When omitted, the running loop is used if
        there is one, otherwise a new loop is created and set as current.
    """
    # pylint: disable=W0603
    global _CONFIG, _CONNECTIONS, _LOOP, _TORTOISE_TEST_DB, _MODULES, _CONN_CONFIG

    _MODULES = modules
    if db_url is not None:  # pragma: nobranch
        _TORTOISE_TEST_DB = db_url
    _CONFIG = getDBConfig(app_label=app_label, modules=_MODULES)

    if not loop:
        # NOTE(review): if a loop is already running, ``run_until_complete``
        # below raises RuntimeError. Call this from synchronous setup code.
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    _LOOP = loop
    loop.run_until_complete(_init_db(_CONFIG))

    # Snapshot the freshly created connection state for ``_restore_default``...
    _CONNECTIONS = connections._copy_storage()
    _CONN_CONFIG = connections.db_config.copy()

    # ...then reset Tortoise so each test case decides how to (re)initialise it.
    connections._clear_storage()
    connections.db_config.clear()
    Tortoise.apps = {}
    Tortoise._inited = False
[docs]
def finalizer() -> None:
    """
    Cleans up the DB after testing. Must be called as part of the test environment teardown.

    Restores the connection state saved by ``initializer``, then drops the test
    databases on the event loop that ``initializer`` used.
    """
    _restore_default()
    _LOOP.run_until_complete(Tortoise._drop_databases())
[docs]
def env_initializer() -> None:  # pragma: nocoverage
    """
    Calls ``initializer()`` with parameters mapped from environment variables.

    ``TORTOISE_TEST_MODULES``:
        A comma-separated list of modules to include *(optional)*.
        If not provided, it will default to ``tests.testmodels``.
        Whitespace around entries is ignored, and so are empty entries.
    ``TORTOISE_TEST_APP``:
        The name of the APP to initialise the modules in *(optional)*.
        If not provided, it will default to "models".
    ``TORTOISE_TEST_DB``:
        The db_url of the test db *(optional)*.
        If not provided, it will default to an in-memory SQLite DB.

    :raises Exception: If ``TORTOISE_TEST_MODULES`` is set but names no module
        (e.g. it is empty or contains only commas/whitespace).
    """
    raw_modules = _os.environ.get("TORTOISE_TEST_MODULES", "tests.testmodels")
    # ``str.split`` always returns at least ``[""]``, so empty entries must be
    # filtered out explicitly or the guard below could never trigger.
    modules = [name.strip() for name in raw_modules.split(",") if name.strip()]
    db_url = _os.environ.get("TORTOISE_TEST_DB", "sqlite://:memory:")
    app_label = _os.environ.get("TORTOISE_TEST_APP", "models")
    if not modules:  # pragma: nocoverage
        raise Exception("TORTOISE_TEST_MODULES envvar not defined")
    initializer(modules, db_url=db_url, app_label=app_label)
[docs]
class SimpleTestCase(unittest.IsolatedAsyncioTestCase):
    """
    The Tortoise base test class.

    This will ensure that your DB environment has a test double set up for use.

    An asyncio capable test class that provides some helper functions.

    Will run any ``test_*()`` function either as sync or async, depending
    on the signature of the function.
    If you specify ``async test_*()`` then it will run it in an event loop.

    Based on `asynctest <http://asynctest.readthedocs.io/>`_
    """

    def _setupAsyncioRunner(self) -> None:
        # Python 3.11+: the runner takes its loop from ``asyncio.get_event_loop``
        # instead of creating a brand-new loop for every test.
        if hasattr(asyncio, "Runner"):  # For python3.11+
            runner = asyncio.Runner(debug=True, loop_factory=asyncio.get_event_loop)
            self._asyncioRunner = runner

    def _tearDownAsyncioRunner(self) -> None:
        # Override runner tear down to avoid eventloop closing before testing completed.
        pass

    async def _setUpDB(self) -> None:
        """Hook run before each test; subclasses prepare the database here."""

    async def _tearDownDB(self) -> None:
        """Hook run after each test; subclasses clean the database here."""

    def _setupAsyncioLoop(self):
        # Pre-3.11 counterpart of ``_setupAsyncioRunner`` (presumably only invoked by
        # IsolatedAsyncioTestCase on Python < 3.11). It reuses the current loop in
        # debug mode and starts the base class's call runner on it.
        loop = asyncio.get_event_loop()
        loop.set_debug(True)
        self._asyncioTestLoop = loop
        fut = loop.create_future()
        self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))  # type: ignore
        loop.run_until_complete(fut)

    def _tearDownAsyncioLoop(self):
        # Put ``None`` on the call queue (presumably the stop signal for the runner)
        # and wait for the queue to drain. The loop itself is not closed here.
        loop = self._asyncioTestLoop
        self._asyncioTestLoop = None  # type: ignore
        self._asyncioCallsQueue.put_nowait(None)  # type: ignore
        loop.run_until_complete(self._asyncioCallsQueue.join())  # type: ignore

    async def asyncSetUp(self) -> None:
        """Prepare the database for the test via the ``_setUpDB`` hook."""
        await self._setUpDB()

    def _reset_conn_state(self) -> None:
        # clearing the storage and db config
        connections._clear_storage()
        connections.db_config.clear()

    async def asyncTearDown(self) -> None:
        """Run the ``_tearDownDB`` hook, then leave Tortoise uninitialised."""
        await self._tearDownDB()
        self._reset_conn_state()
        Tortoise.apps = {}
        Tortoise._inited = False

    def assertListSortEqual(
        self, list1: list[Any], list2: list[Any], msg: Any = None, sorted_key: str | None = None
    ) -> None:
        """
        Assert that two lists contain the same items, ignoring their order.

        The first element of ``list1`` picks the sort key. ``Model`` instances are
        sorted by ``pk``, and dicts by ``sorted_key`` when one is given. Anything
        else uses natural ordering.

        :param list1: First list.
        :param list2: Second list.
        :param msg: Optional extra message shown on failure.
        :param sorted_key: Dict key to sort by when the lists contain dicts.
        """
        if not list1 or not list2:
            # With an empty side there is nothing to sort by, and an empty list
            # only equals another empty list, so compare directly.
            super().assertListEqual(list(list1), list(list2), msg=msg)
            return
        if isinstance(list1[0], Model):
            super().assertListEqual(
                sorted(list1, key=lambda x: x.pk), sorted(list2, key=lambda x: x.pk), msg=msg
            )
        elif isinstance(list1[0], dict) and sorted_key:
            super().assertListEqual(
                sorted(list1, key=lambda x: x[sorted_key]),
                sorted(list2, key=lambda x: x[sorted_key]),
                msg=msg,
            )
        else:
            super().assertListEqual(sorted(list1), sorted(list2), msg=msg)
[docs]
class IsolatedTestCase(SimpleTestCase):
    """
    An asyncio capable test class that will ensure that an isolated test db
    is available for each test.

    Use this if your test needs perfect isolation.

    Note: use ``{}`` as a string-replacement parameter in your DB_URL;
    that will create a randomised database name.

    It will create and destroy a new DB instance for every test.
    This is obviously slow, but guarantees a fresh DB.

    If you define a ``tortoise_test_modules`` list, it overrides the DB setup module for the tests.
    """

    # Model modules for this test class. When empty, the modules passed to
    # ``initializer`` are used instead.
    tortoise_test_modules: Iterable[str | ModuleType] = []

    async def _setUpDB(self) -> None:
        await super()._setUpDB()
        config = getDBConfig(app_label="models", modules=self.tortoise_test_modules or _MODULES)
        await Tortoise.init(config, _create_db=True)
        await Tortoise.generate_schemas(safe=False)

    async def _tearDownDB(self) -> None:
        await Tortoise._drop_databases()
        # Chain to the base hook like TruncationTestCase does, so any teardown
        # added to SimpleTestCase is not silently skipped here.
        await super()._tearDownDB()
[docs]
class TruncationTestCase(SimpleTestCase):
    """
    An asyncio capable test class that will truncate the tables after a test.

    Use this when your tests contain transactions.

    This is slower than ``TestCase`` but faster than ``IsolatedTestCase``.
    Note that usage of this does not guarantee that auto-number-pks will be reset to 1.
    """

    async def _setUpDB(self) -> None:
        await super()._setUpDB()
        # Reattach to the connections and apps saved by ``initializer``.
        _restore_default()

    async def _tearDownDB(self) -> None:
        _restore_default()
        # Delete every row written during the test; the schema itself is kept.
        await truncate_all_models()
        await super()._tearDownDB()
class _RollbackException(Exception):
pass
[docs]
class TestCase(TruncationTestCase):
    """
    An asyncio capable test class that will ensure that each test will be run at
    separate transaction that will rollback on finish.

    This is a fast test runner. Don't use it if your test uses transactions.
    """

    async def asyncSetUp(self) -> None:
        """Restore the test DB, then open a transaction on the ``models`` connection."""
        await super().asyncSetUp()
        self._db = connections.get("models")
        self._transaction = self._db._in_transaction()
        await self._transaction.__aenter__()

    async def asyncTearDown(self) -> None:
        """Roll back the per-test transaction, then run the base teardown."""
        try:
            # Exiting the transaction context with an exception causes a rollback.
            await self._transaction.__aexit__(_RollbackException, _RollbackException(), None)
        finally:
            # Always reset connection/Tortoise state, even if the rollback failed,
            # so later tests don't inherit a half-torn-down global environment.
            await super().asyncTearDown()

    async def _tearDownDB(self) -> None:
        if self._db.capabilities.supports_transactions:
            # The rollback discarded the test's writes; only restore connections.
            _restore_default()
        else:
            # Without transaction support the rollback presumably could not undo
            # the writes, so fall back to truncating every table.
            await super()._tearDownDB()
[docs]
def requireCapability(connection_name: str = "models", **conditions: Any) -> Callable:
"""
Skip a test if the required capabilities are not matched.
.. note::
The database must be initialized *before* the decorated test runs.
Usage:
.. code-block:: python3
@requireCapability(dialect='sqlite')
async def test_run_sqlite_only(self):
...
Or to conditionally skip a class:
.. code-block:: python3
@requireCapability(dialect='sqlite')
class TestSqlite(test.TestCase):
...
:param connection_name: name of the connection to retrieve capabilities from.
:param conditions: capability tests which must all pass for the test to run.
"""
def decorator(test_item):
if not isinstance(test_item, type):
def check_capabilities() -> None:
db = connections.get(connection_name)
for key, val in conditions.items():
if getattr(db.capabilities, key) != val:
raise SkipTest(f"Capability {key} != {val}")
if hasattr(asyncio, "Runner") and inspect.iscoroutinefunction(test_item):
# For python3.11+
@wraps(test_item)
async def skip_wrapper(*args, **kwargs):
check_capabilities()
return await test_item(*args, **kwargs)
else:
@wraps(test_item)
def skip_wrapper(*args, **kwargs):
check_capabilities()
return test_item(*args, **kwargs)
return skip_wrapper
# Assume a class is decorated
funcs = {
var: f
for var in dir(test_item)
if var.startswith("test_") and callable(f := getattr(test_item, var))
}
for name, func in funcs.items():
setattr(
test_item,
name,
requireCapability(connection_name=connection_name, **conditions)(func),
)
return test_item
return decorator
T = TypeVar("T")
P = ParamSpec("P")
AsyncFunc = Callable[P, Coroutine[None, None, T]]
AsyncFuncDeco = Callable[..., AsyncFunc]
ModulesConfigType = Union[str, list[str]]
MEMORY_SQLITE = "sqlite://:memory:"
@typing.overload
def init_memory_sqlite(models: ModulesConfigType | None = None) -> AsyncFuncDeco: ...


@typing.overload
def init_memory_sqlite(models: AsyncFunc) -> AsyncFunc: ...


def init_memory_sqlite(
    models: ModulesConfigType | AsyncFunc | None = None,
) -> AsyncFunc | AsyncFuncDeco:
    """
    For single file style to run code with memory sqlite

    Can be used bare (``@init_memory_sqlite``) or called with the model modules
    (``@init_memory_sqlite(models=[...])``). Each call of the decorated coroutine
    initialises Tortoise on an in-memory SQLite DB and generates the schema
    before running it.

    :param models: list_of_modules that should be discovered for models, default to ['__main__'].
        A single module path string, or any iterable of module paths, is accepted.

    Usage:

    .. code-block:: python3

        from tortoise import fields, models, run_async
        from tortoise.contrib.test import init_memory_sqlite

        class MyModel(models.Model):
            id = fields.IntField(primary_key=True)
            name = fields.TextField()

        @init_memory_sqlite
        async def run():
            obj = await MyModel.create(name='')
            assert obj.id == 1

        if __name__ == '__main__':
            run_async(run)

    Custom models example:

    .. code-block:: python3

        @init_memory_sqlite(models=['app.models', 'aerich.models'])
        async def run():
            ...
    """

    def wrapper(func: AsyncFunc, ms: list[str]):
        @wraps(func)
        async def runner(*args, **kwargs) -> T:
            await Tortoise.init(db_url=MEMORY_SQLITE, modules={"models": ms})
            await Tortoise.generate_schemas()
            return await func(*args, **kwargs)

        return runner

    default_models = ["__main__"]
    if inspect.iscoroutinefunction(models):
        # Used bare as ``@init_memory_sqlite``: *models* is the decorated function.
        return wrapper(models, default_models)
    if models is None:
        module_list = default_models
    elif isinstance(models, str):
        module_list = [models]
    else:
        # Materialise into a list so tuples and one-shot iterables work, and a
        # generator is not exhausted after the first run of the decorated function.
        module_list = list(cast(list, models))
    return partial(wrapper, ms=module_list)