import asyncio
import inspect
import os as _os
import sys
import typing
import unittest
from asyncio.events import AbstractEventLoop
from functools import partial, wraps
from types import ModuleType
from typing import (
Any,
Callable,
Coroutine,
Iterable,
List,
Optional,
TypeVar,
Union,
cast,
)
from unittest import SkipTest, expectedFailure, skip, skipIf, skipUnless
from tortoise import Model, Tortoise, connections
from tortoise.backends.base.config_generator import generate_config as _generate_config
from tortoise.exceptions import DBConnectionError, OperationalError
if sys.version_info >= (3, 10):
from typing import ParamSpec
else:
from typing_extensions import ParamSpec
# Public API of this test-support module.
__all__ = (
"MEMORY_SQLITE",
"SimpleTestCase",
"TestCase",
"TruncationTestCase",
"IsolatedTestCase",
"getDBConfig",
"requireCapability",
"env_initializer",
"initializer",
"finalizer",
"SkipTest",
"expectedFailure",
"skip",
"skipIf",
"skipUnless",
"init_memory_sqlite",
)
# DB URL used by getDBConfig(); initializer() overwrites it when a db_url is passed.
_TORTOISE_TEST_DB = "sqlite://:memory:"
# pylint: disable=W0201
# Give the re-exported unittest.expectedFailure a docstring.
# NOTE: this mutates the stdlib function object, so it is visible process-wide.
expectedFailure.__doc__ = """
Mark test as expecting failure.
On success it will be marked as unexpected success.
"""
# Module-level state written by initializer() and read back by finalizer(),
# _restore_default() and the test-case classes.
# Generated Tortoise config (see getDBConfig); its "apps" entry is re-registered per test.
_CONFIG: dict = {}
# Snapshot of the connection storage taken right after the test DB was initialised.
_CONNECTIONS: dict = {}
# Event loop that initializer() ran on; finalizer() drops the databases on it.
_LOOP: AbstractEventLoop = None # type: ignore
# Model modules passed to initializer(); fallback for IsolatedTestCase.
_MODULES: Iterable[Union[str, ModuleType]] = []
# Snapshot of connections.db_config taken right after the test DB was initialised.
_CONN_CONFIG: dict = {}
def getDBConfig(app_label: str, modules: Iterable[Union[str, ModuleType]]) -> dict:
    """
    DB Config factory, for use in testing.

    The database URL is the module-level test DB URL (``initializer()`` may
    override it). ``app_label`` is used both as the app name and as the
    connection label.

    :param app_label: Label of the app (must be distinct for multiple apps).
    :param modules: List of modules to look for models in.
    :return: A Tortoise configuration dict generated in testing mode.
    """
    apps = {app_label: modules}
    return _generate_config(
        _TORTOISE_TEST_DB,
        connection_label=app_label,
        testing=True,
        app_modules=apps,
    )
async def _init_db(config: dict) -> None:
    """
    Initialise Tortoise from ``config`` and build a fresh schema.

    Existing databases are dropped first. A DBConnectionError or
    OperationalError raised while dropping is ignored. The databases are then
    (re)created and the schema is generated with ``safe=False``.

    :param config: Tortoise configuration dict (see ``getDBConfig``).
    """
    # Init happens outside the try block: it does not connect to the DB eagerly.
    await Tortoise.init(config)
    ignorable_errors = (DBConnectionError, OperationalError)
    try:
        await Tortoise._drop_databases()
    except ignorable_errors:  # pragma: nocoverage
        # Presumably there was nothing to drop yet; carry on and create it.
        pass
    await Tortoise.init(config, _create_db=True)
    await Tortoise.generate_schemas(safe=False)
def _restore_default() -> None:
    """
    Reinstate the Tortoise state captured by ``initializer()``.

    The app registry is reset. Copies of the saved connection storage and
    connection config are put back, the apps from ``_CONFIG`` are re-registered,
    and Tortoise is flagged as initialised.
    """
    Tortoise.apps = {}
    storage = connections._get_storage()
    storage.update(_CONNECTIONS.copy())
    connections._db_config = _CONN_CONFIG.copy()
    Tortoise._init_apps(_CONFIG["apps"])
    Tortoise._inited = True
def initializer(
    modules: Iterable[Union[str, ModuleType]],
    db_url: Optional[str] = None,
    app_label: str = "models",
    loop: Optional[AbstractEventLoop] = None,
) -> None:
    """
    Sets up the DB for testing. Must be called as part of test environment setup.

    The test database(s) and schema are created on ``loop``. A snapshot of the
    resulting connections and connection config is stored in module globals,
    where the test-case classes and ``finalizer()`` read it back. The live
    Tortoise/connection state is then cleared.

    :param modules: List of modules to look for models in.
    :param db_url: The db_url. When omitted, the current test DB URL is kept
        (initially ``sqlite://:memory:``).
    :param app_label: The name of the APP to initialise the modules in, defaults to "models"
    :param loop: Optional event loop. When omitted, the current event loop is
        used. If there is none, a new loop is created and set as the current one.
    """
    # pylint: disable=W0603
    global _CONFIG
    global _CONNECTIONS
    global _LOOP
    global _TORTOISE_TEST_DB
    global _MODULES
    global _CONN_CONFIG
    _MODULES = modules
    if db_url is not None:  # pragma: nobranch
        _TORTOISE_TEST_DB = db_url
    _CONFIG = getDBConfig(app_label=app_label, modules=_MODULES)
    if loop is None:
        try:
            loop = asyncio.get_event_loop()
        except RuntimeError:
            # Newer Python versions (and non-main threads) raise here instead of
            # implicitly creating a loop. Create one and make it current so that
            # later asyncio.get_event_loop() calls in this thread return it.
            loop = asyncio.new_event_loop()
            asyncio.set_event_loop(loop)
    _LOOP = loop
    loop.run_until_complete(_init_db(_CONFIG))
    # Snapshot the initialised state so it can be restored for each test,
    # then leave Tortoise in an uninitialised state.
    _CONNECTIONS = connections._copy_storage()
    _CONN_CONFIG = connections.db_config.copy()
    connections._clear_storage()
    connections.db_config.clear()
    Tortoise.apps = {}
    Tortoise._inited = False
def finalizer() -> None:
    """
    Cleans up the DB after testing. Must be called as part of the test environment teardown.

    The state captured by ``initializer()`` is restored first. The test
    databases are then dropped on the event loop ``initializer()`` used.
    """
    _restore_default()
    _LOOP.run_until_complete(Tortoise._drop_databases())
def env_initializer() -> None:  # pragma: nocoverage
    """
    Calls ``initializer()`` with parameters mapped from environment variables.

    ``TORTOISE_TEST_MODULES``:
        A comma-separated list of modules to include.
        If not set, it defaults to ``tests.testmodels``. Surrounding whitespace
        is stripped and empty entries are ignored. A value that yields no
        module names raises an error.
    ``TORTOISE_TEST_APP``:
        The name of the APP to initialise the modules in *(optional)*
        If not provided, it will default to "models".
    ``TORTOISE_TEST_DB``:
        The db_url of the test db. *(optional)*
        If not provided, it will default to an in-memory SQLite DB.

    :raises Exception: if ``TORTOISE_TEST_MODULES`` yields no module names.
    """
    raw_modules = _os.environ.get("TORTOISE_TEST_MODULES", "tests.testmodels")
    # Strip each entry and drop empty ones. Values like "a, b" then work, and an
    # empty value is actually caught below; str.split() never returns [].
    modules = [name.strip() for name in raw_modules.split(",") if name.strip()]
    db_url = _os.environ.get("TORTOISE_TEST_DB", "sqlite://:memory:")
    app_label = _os.environ.get("TORTOISE_TEST_APP", "models")
    if not modules:  # pragma: nocoverage
        raise Exception("TORTOISE_TEST_MODULES envvar not defined")
    initializer(modules, db_url=db_url, app_label=app_label)
class SimpleTestCase(unittest.IsolatedAsyncioTestCase):
    """
    The Tortoise base test class.

    This will ensure that your DB environment has a test double set up for use.

    An asyncio capable test class that provides some helper functions.

    Will run any ``test_*()`` function either as sync or async, depending
    on the signature of the function.
    If you specify ``async test_*()`` then it will run it in an event loop.

    Based on `asynctest <http://asynctest.readthedocs.io/>`_
    """

    def _setupAsyncioRunner(self) -> None:
        # Only when asyncio.Runner exists (Python 3.11+): build the runner on the
        # current event loop (asyncio.get_event_loop) rather than a fresh loop.
        # Presumably this lets tests share the loop initializer() used.
        if hasattr(asyncio, "Runner"):  # For python3.11+
            runner = asyncio.Runner(debug=True, loop_factory=asyncio.get_event_loop)
            self._asyncioRunner = runner

    def _tearDownAsyncioRunner(self) -> None:
        # Override runner tear down to avoid eventloop closing before testing completed.
        pass

    async def _setUpDB(self) -> None:
        # Hook for subclasses: prepare the database before each test.
        pass

    async def _tearDownDB(self) -> None:
        # Hook for subclasses: clean up the database after each test.
        pass

    def _setupAsyncioLoop(self):
        # Reuses the current event loop instead of creating a new one, then starts
        # the task that processes queued test calls. (Presumably the hook used by
        # the pre-3.11 IsolatedAsyncioTestCase.)
        loop = asyncio.get_event_loop()
        loop.set_debug(True)
        self._asyncioTestLoop = loop
        fut = loop.create_future()
        self._asyncioCallsTask = loop.create_task(self._asyncioLoopRunner(fut))  # type: ignore
        loop.run_until_complete(fut)

    def _tearDownAsyncioLoop(self):
        # Queue ``None`` (presumably the stop sentinel for the call runner) and wait
        # for the queue to drain. The loop itself is deliberately not closed here.
        loop = self._asyncioTestLoop
        self._asyncioTestLoop = None  # type: ignore
        self._asyncioCallsQueue.put_nowait(None)  # type: ignore
        loop.run_until_complete(self._asyncioCallsQueue.join())  # type: ignore

    async def asyncSetUp(self) -> None:
        """Prepare the database for the test via the ``_setUpDB`` hook."""
        await self._setUpDB()

    def _reset_conn_state(self) -> None:
        # clearing the storage and db config
        connections._clear_storage()
        connections.db_config.clear()

    async def asyncTearDown(self) -> None:
        """Run the ``_tearDownDB`` hook, then reset connection and Tortoise app state."""
        await self._tearDownDB()
        self._reset_conn_state()
        Tortoise.apps = {}
        Tortoise._inited = False

    def assertListSortEqual(
        self,
        list1: List[Any],
        list2: List[Any],
        msg: Any = None,
        sorted_key: Optional[str] = None,
    ) -> None:
        """
        Assert that two lists contain the same items, ignoring order.

        The element kind is detected from the first item of ``list1``.
        ``Model`` instances are sorted by ``pk``. Dicts are sorted by
        ``sorted_key`` when it is given. Anything else uses natural ordering.
        An empty ``list1`` is compared directly against ``list2``.

        :param list1: First list.
        :param list2: Second list.
        :param msg: Optional message shown on failure.
        :param sorted_key: Dict key to sort by when the items are dicts.
        :raises AssertionError: if the lists differ after sorting.
        """
        if not list1:
            # No first element to inspect. A direct comparison is correct and
            # gives a useful failure message when list2 is non-empty.
            super().assertListEqual(list(list1), list(list2), msg=msg)
            return
        if isinstance(list1[0], Model):
            super().assertListEqual(
                sorted(list1, key=lambda x: x.pk), sorted(list2, key=lambda x: x.pk), msg=msg
            )
        elif isinstance(list1[0], dict) and sorted_key:
            super().assertListEqual(
                sorted(list1, key=lambda x: x[sorted_key]),
                sorted(list2, key=lambda x: x[sorted_key]),
                msg=msg,
            )
        else:
            super().assertListEqual(sorted(list1), sorted(list2), msg=msg)
class IsolatedTestCase(SimpleTestCase):
    """
    An asyncio capable test class that will ensure that an isolated test db
    is available for each test.

    Use this if your test needs perfect isolation.

    Note: use ``{}`` as a string-replacement parameter in your DB_URL.
    That will create a randomised database name.

    It will create and destroy a new DB instance for every test.
    This is obviously slow, but guarantees a fresh DB.

    If you define a ``tortoise_test_modules`` list, it overrides the DB setup module for the tests.
    """

    # Model modules for this test class; when empty, the modules given to
    # initializer() are used instead.
    tortoise_test_modules: Iterable[Union[str, ModuleType]] = []

    async def _setUpDB(self) -> None:
        """Create a new database and generate its schema for the current test."""
        await super()._setUpDB()
        config = getDBConfig(app_label="models", modules=self.tortoise_test_modules or _MODULES)
        await Tortoise.init(config, _create_db=True)
        await Tortoise.generate_schemas(safe=False)

    async def _tearDownDB(self) -> None:
        """Drop the per-test database, then run the base-class teardown hook."""
        await Tortoise._drop_databases()
        # Chain to the base hook, mirroring _setUpDB and TruncationTestCase.
        await super()._tearDownDB()
class TruncationTestCase(SimpleTestCase):
    """
    An asyncio capable test class that will truncate the tables after a test.

    Use this when your tests contain transactions.

    This is slower than ``TestCase`` but faster than ``IsolatedTestCase``.
    Note that usage of this does not guarantee that auto-number-pks will be reset to 1.
    """

    async def _setUpDB(self) -> None:
        """Run the base hook, then restore the state captured by initializer()."""
        await super()._setUpDB()
        _restore_default()

    async def _tearDownDB(self) -> None:
        """Issue a DELETE against every registered model's table, then run the base hook."""
        _restore_default()
        # TODO: This is a naive implementation: Will fail to clear M2M and non-cascade foreign keys
        registered_models = (
            model
            for app_models in Tortoise.apps.values()
            for model in app_models.values()
        )
        for model in registered_models:
            quote = model._meta.db.query_class._builder().QUOTE_CHAR
            table = f"{quote}{model._meta.db_table}{quote}"
            await model._meta.db.execute_script(f"DELETE FROM {table}")  # nosec
        await super()._tearDownDB()
class TransactionTestContext:
    """
    Async context manager that runs code inside a transaction that is always
    rolled back on exit.

    On enter it:
    - makes sure the underlying connection exists, acquiring one from the pool
      when the parent client has a ``_pool`` attribute;
    - registers the transaction connection under its name in ``connections``;
    - starts the transaction.

    On exit it rolls back, returns a pooled connection to the pool, and
    restores the previous ``connections`` entry.
    """

    __slots__ = ("connection", "connection_name", "token", "uses_pool")

    def __init__(self, connection) -> None:
        self.connection = connection
        self.connection_name = connection.connection_name
        # A parent client with a ``_pool`` attribute is treated as pooled.
        self.uses_pool = hasattr(self.connection._parent, "_pool")

    async def ensure_connection(self) -> None:
        """
        Make sure ``self.connection`` has an underlying DB connection.

        Creates the parent's pool/connection if it does not exist yet. Then it
        either acquires a connection from the pool or reuses the parent's
        single connection.
        """
        if self.uses_pool:
            is_conn_established = self.connection._parent._pool is not None
        else:
            is_conn_established = self.connection._connection is not None
        # If the underlying pool/connection hasn't been established then
        # first create the pool/connection
        if not is_conn_established:
            await self.connection._parent.create_connection(with_db=True)
        if self.uses_pool:
            self.connection._connection = await self.connection._parent._pool.acquire()
        else:
            self.connection._connection = self.connection._parent._connection

    async def __aenter__(self):
        await self.ensure_connection()
        self.token = connections.set(self.connection_name, self.connection)
        await self.connection.start()
        return self.connection

    async def __aexit__(self, exc_type, exc_val, exc_tb) -> None:
        # Always roll back so the changes are never committed. The cleanup in the
        # finally blocks runs even if the rollback raises, so neither the pooled
        # connection nor the ``connections`` entry is leaked.
        try:
            await self.connection.rollback()
        finally:
            try:
                if self.uses_pool:
                    await self.connection._parent._pool.release(self.connection._connection)
            finally:
                connections.reset(self.token)
class TestCase(TruncationTestCase):
    """
    An asyncio capable test class that will ensure that each test will be run at
    separate transaction that will rollback on finish.

    This is a fast test runner. Don't use it if your test uses transactions.
    """

    async def asyncSetUp(self) -> None:
        """Run base setup, then open a rollback-only transaction on the "models" connection."""
        await super().asyncSetUp()
        self._db = connections.get("models")
        self._transaction = TransactionTestContext(self._db._in_transaction().connection)
        await self._transaction.__aenter__()  # type: ignore

    async def asyncTearDown(self) -> None:
        """Roll back the test transaction, then run the base teardown even if that fails."""
        try:
            await self._transaction.__aexit__(None, None, None)
        finally:
            # Base teardown resets connection/app state; skipping it would leak
            # state into the next test.
            await super().asyncTearDown()

    async def _tearDownDB(self) -> None:
        """
        Restore default state without truncating when the backend supports
        transactions; this relies on the rolled-back transaction to undo the
        test's writes. Otherwise fall back to table truncation.
        """
        if self._db.capabilities.supports_transactions:
            _restore_default()
        else:
            await super()._tearDownDB()
[docs]def requireCapability(connection_name: str = "models", **conditions: Any) -> Callable:
"""
Skip a test if the required capabilities are not matched.
.. note::
The database must be initialized *before* the decorated test runs.
Usage:
.. code-block:: python3
@requireCapability(dialect='sqlite')
async def test_run_sqlite_only(self):
...
Or to conditionally skip a class:
.. code-block:: python3
@requireCapability(dialect='sqlite')
class TestSqlite(test.TestCase):
...
:param connection_name: name of the connection to retrieve capabilities from.
:param conditions: capability tests which must all pass for the test to run.
"""
def decorator(test_item):
if not isinstance(test_item, type):
def check_capabilities() -> None:
db = connections.get(connection_name)
for key, val in conditions.items():
if getattr(db.capabilities, key) != val:
raise SkipTest(f"Capability {key} != {val}")
if hasattr(asyncio, "Runner") and inspect.iscoroutinefunction(test_item):
# For python3.11+
@wraps(test_item)
async def skip_wrapper(*args, **kwargs):
check_capabilities()
return await test_item(*args, **kwargs)
else:
@wraps(test_item)
def skip_wrapper(*args, **kwargs):
check_capabilities()
return test_item(*args, **kwargs)
return skip_wrapper
# Assume a class is decorated
funcs = {
var: f
for var in dir(test_item)
if var.startswith("test_") and callable(f := getattr(test_item, var))
}
for name, func in funcs.items():
setattr(
test_item,
name,
requireCapability(connection_name=connection_name, **conditions)(func),
)
return test_item
return decorator
T = TypeVar("T")
P = ParamSpec("P")
AsyncFunc = Callable[P, Coroutine[None, None, T]]
AsyncFuncDeco = Callable[..., AsyncFunc]
ModulesConfigType = Union[str, List[str]]
MEMORY_SQLITE = "sqlite://:memory:"
@typing.overload
def init_memory_sqlite(models: Union[ModulesConfigType, None] = None) -> AsyncFuncDeco: ...
@typing.overload
def init_memory_sqlite(models: AsyncFunc) -> AsyncFunc: ...
[docs]def init_memory_sqlite(
models: Union[ModulesConfigType, AsyncFunc, None] = None
) -> Union[AsyncFunc, AsyncFuncDeco]:
"""
For single file style to run code with memory sqlite
:param models: list_of_modules that should be discovered for models, default to ['__main__'].
Usage:
.. code-block:: python3
from tortoise import fields, models, run_async
from tortoise.contrib.test import init_memory_sqlite
class MyModel(models.Model):
id = fields.IntField(primary_key=True)
name = fields.TextField()
@init_memory_sqlite
async def run():
obj = await MyModel.create(name='')
assert obj.id == 1
if __name__ == '__main__'
run_async(run)
Custom models example:
.. code-block:: python3
@init_memory_sqlite(models=['app.models', 'aerich.models'])
async def run():
...
"""
def wrapper(func: AsyncFunc, ms: List[str]):
@wraps(func)
async def runner(*args, **kwargs) -> T:
await Tortoise.init(db_url=MEMORY_SQLITE, modules={"models": ms})
await Tortoise.generate_schemas()
return await func(*args, **kwargs)
return runner
default_models = ["__main__"]
if inspect.iscoroutinefunction(models):
return wrapper(models, default_models)
if models is None:
models = default_models
elif isinstance(models, str):
models = [models]
else:
models = cast(list, models)
return partial(wrapper, ms=models)