diff --git a/pytest_django/fixtures.py b/pytest_django/fixtures.py index adc5dea7f..024aaa017 100644 --- a/pytest_django/fixtures.py +++ b/pytest_django/fixtures.py @@ -16,6 +16,7 @@ from .lazy_django import get_django_version, skip_if_no_django __all__ = ['django_db_setup', 'db', 'transactional_db', 'admin_user', + 'django_db_testcase', 'django_transactional_db_testcase', 'django_user_model', 'django_username_field', 'client', 'admin_client', 'rf', 'settings', 'live_server', '_live_server_helper', 'django_assert_num_queries'] @@ -110,6 +111,18 @@ def teardown_database(): request.addfinalizer(teardown_database) +@pytest.fixture +def django_db_testcase(request): + from django.test import TestCase + return TestCase + + +@pytest.fixture +def django_transactional_db_testcase(request): + from django.test import TransactionTestCase + return TransactionTestCase + + def _django_db_fixture_helper(transactional, request, django_db_blocker): if is_django_unittest(request): return @@ -121,10 +134,9 @@ def _django_db_fixture_helper(transactional, request, django_db_blocker): django_db_blocker.unblock() request.addfinalizer(django_db_blocker.restore) - if transactional: - from django.test import TransactionTestCase as django_case - else: - from django.test import TestCase as django_case + testcase_class_fixture = ('django_transactional_db_testcase' + if transactional else 'django_db_testcase') + django_case = getfixturevalue(request, testcase_class_fixture) test_case = django_case(methodName='__init__') test_case._pre_setup() diff --git a/pytest_django/plugin.py b/pytest_django/plugin.py index 4c6961a36..f7ccfbda3 100644 --- a/pytest_django/plugin.py +++ b/pytest_django/plugin.py @@ -32,6 +32,8 @@ from .fixtures import rf # noqa from .fixtures import settings # noqa from .fixtures import transactional_db # noqa +from .fixtures import django_db_testcase # noqa +from .fixtures import django_transactional_db_testcase # noqa from .pytest_compat import getfixturevalue from .lazy_django import (django_settings_is_configured,