# Extracted from a repository web view
# (navigation: summary | refs | log | tree | commit | diff | homepage)
# path: root/libs/sqlalchemy/testing/asyncio.py
# blob: 10d3d079d7f5259bedc60a8931ed86f09f9abf6c (plain)
# testing/asyncio.py
# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors
# <see AUTHORS file>
#
# This module is part of SQLAlchemy and is released under
# the MIT License: https://www.opensource.org/licenses/mit-license.php
# mypy: ignore-errors


# functions and wrappers to run tests, fixtures, provisioning and
# setup/teardown in an asyncio event loop, conditionally based on the
# current DB driver being used for a test.

# note that SQLAlchemy's asyncio integration also supports a method
# of running individual asyncio functions inside of separate event loops
# using "async_fallback" mode; however running whole functions in the event
# loop is a more accurate test for how SQLAlchemy's asyncio features
# would run in the real world.


from __future__ import annotations

from functools import wraps
import inspect

from . import config
from ..util.concurrency import _util_async_run
from ..util.concurrency import _util_async_run_coroutine_function

# Module-wide switch consulted by _assume_async, _maybe_async_provisioning
# and _maybe_async: when it is False those helpers call the target function
# directly rather than handing it to _util_async_run.
# May be set to False if the --disable-asyncio flag is passed to the test
# runner.
ENABLE_ASYNCIO = True


def _run_coroutine_function(fn, *args, **kwargs):
    """Forward ``fn`` and its arguments to
    ``_util_async_run_coroutine_function`` and return whatever it returns.

    This is a thin pass-through; the ``ENABLE_ASYNCIO`` flag is not
    consulted here.

    """
    outcome = _util_async_run_coroutine_function(fn, *args, **kwargs)
    return outcome


def _assume_async(fn, *args, **kwargs):
    """Run ``fn`` in an asyncio loop regardless of the driver in use.

    Intended for provisioning features, such as testing a database
    connection for server info.

    The only case in which ``fn`` is called directly is when asyncio has
    been switched off globally via ``ENABLE_ASYNCIO``.

    Note that for blocking IO database drivers, running inside the loop
    means they block the event loop.

    """
    if ENABLE_ASYNCIO:
        return _util_async_run(fn, *args, **kwargs)

    # asyncio disabled for the whole run: plain synchronous call.
    return fn(*args, **kwargs)


def _maybe_async_provisioning(fn, *args, **kwargs):
    """Run ``fn`` in an asyncio loop when any configured driver may need it.

    Provisioning features take place outside of a specific database driver
    being selected.  If the driver that happens to be used for the
    provisioning operation is an async driver, running in asyncio keeps
    it from failing.

    ``fn`` is called directly when asyncio is disabled via
    ``ENABLE_ASYNCIO`` or when ``config.any_async`` is false.

    Note that for blocking IO database drivers, running inside the loop
    means they block the event loop.

    """
    # config.any_async is only consulted when asyncio is enabled.
    if ENABLE_ASYNCIO and config.any_async:
        return _util_async_run(fn, *args, **kwargs)

    return fn(*args, **kwargs)


def _maybe_async(fn, *args, **kwargs):
    """Run ``fn`` in an asyncio loop when the currently selected driver is
    async; otherwise call it directly.

    Used for test setup/teardown and for tests themselves, where the
    current DB driver is known.  ``fn`` is also called directly when
    asyncio is disabled via ``ENABLE_ASYNCIO``.

    """
    # Short-circuit keeps config._current untouched when asyncio is
    # disabled.
    run_in_loop = ENABLE_ASYNCIO and config._current.is_async

    if not run_in_loop:
        return fn(*args, **kwargs)

    return _util_async_run(fn, *args, **kwargs)


def _maybe_async_wrapper(fn):
    """Return a callable that routes calls to ``fn`` through
    ``_maybe_async``.

    A plain function is wrapped so that the whole call is handed to
    ``_maybe_async``.  A generator function is wrapped in a new generator
    which advances the underlying generator one step at a time, each step
    going through ``_maybe_async``, and re-yields every value produced.

    This is currently used for pytest fixtures that support generator use.

    """
    if not inspect.isgeneratorfunction(fn):

        @wraps(fn)
        def wrap_fixture(*args, **kwargs):
            return _maybe_async(fn, *args, **kwargs)

        return wrap_fixture

    # Unique marker reporting exhaustion of the wrapped generator; a
    # StopIteration can't be raised in an awaitable, so it is translated
    # into this value instead.
    exhausted = object()

    def advance(iterator):
        return next(iterator, exhausted)

    @wraps(fn)
    def wrap_fixture(*args, **kwargs):
        iterator = fn(*args, **kwargs)
        item = _maybe_async(advance, iterator)
        while item is not exhausted:
            yield item
            item = _maybe_async(advance, iterator)

    return wrap_fixture