diff options
Diffstat (limited to 'libs/sqlalchemy/testing/engines.py')
-rw-r--r-- | libs/sqlalchemy/testing/engines.py | 470 |
1 files changed, 470 insertions, 0 deletions
diff --git a/libs/sqlalchemy/testing/engines.py b/libs/sqlalchemy/testing/engines.py new file mode 100644 index 000000000..fa70de6f0 --- /dev/null +++ b/libs/sqlalchemy/testing/engines.py @@ -0,0 +1,470 @@ +# testing/engines.py +# Copyright (C) 2005-2023 the SQLAlchemy authors and contributors +# <see AUTHORS file> +# +# This module is part of SQLAlchemy and is released under +# the MIT License: https://www.opensource.org/licenses/mit-license.php +# mypy: ignore-errors + + +from __future__ import annotations + +import collections +import re +import typing +from typing import Any +from typing import Dict +from typing import Optional +import warnings +import weakref + +from . import config +from .util import decorator +from .util import gc_collect +from .. import event +from .. import pool +from ..util import await_only +from ..util.typing import Literal + + +if typing.TYPE_CHECKING: + from ..engine import Engine + from ..engine.url import URL + from ..ext.asyncio import AsyncEngine + + +class ConnectionKiller: + def __init__(self): + self.proxy_refs = weakref.WeakKeyDictionary() + self.testing_engines = collections.defaultdict(set) + self.dbapi_connections = set() + + def add_pool(self, pool): + event.listen(pool, "checkout", self._add_conn) + event.listen(pool, "checkin", self._remove_conn) + event.listen(pool, "close", self._remove_conn) + event.listen(pool, "close_detached", self._remove_conn) + # note we are keeping "invalidated" here, as those are still + # opened connections we would like to roll back + + def _add_conn(self, dbapi_con, con_record, con_proxy): + self.dbapi_connections.add(dbapi_con) + self.proxy_refs[con_proxy] = True + + def _remove_conn(self, dbapi_conn, *arg): + self.dbapi_connections.discard(dbapi_conn) + + def add_engine(self, engine, scope): + self.add_pool(engine.pool) + + assert scope in ("class", "global", "function", "fixture") + self.testing_engines[scope].add(engine) + + def _safe(self, fn): + try: + fn() + except Exception as e: + warnings.warn( + "testing_reaper couldn't rollback/close connection: %s" % e + ) + + def rollback_all(self): + for rec in list(self.proxy_refs): + if rec is not None and rec.is_valid: + self._safe(rec.rollback) + + def checkin_all(self): + # run pool.checkin() for all ConnectionFairy instances we have + # tracked. + + for rec in list(self.proxy_refs): + if rec is not None and rec.is_valid: + self.dbapi_connections.discard(rec.dbapi_connection) + self._safe(rec._checkin) + + # for fairy refs that were GCed and could not close the connection, + # such as asyncio, roll back those remaining connections + for con in self.dbapi_connections: + self._safe(con.rollback) + self.dbapi_connections.clear() + + def close_all(self): + self.checkin_all() + + def prepare_for_drop_tables(self, connection): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return + + from . import provision + + provision.prepare_for_drop_tables(connection.engine.url, connection) + + def _drop_testing_engines(self, scope): + eng = self.testing_engines[scope] + for rec in list(eng): + for proxy_ref in list(self.proxy_refs): + if proxy_ref is not None and proxy_ref.is_valid: + if ( + proxy_ref._pool is not None + and proxy_ref._pool is rec.pool + ): + self._safe(proxy_ref._checkin) + + if hasattr(rec, "sync_engine"): + await_only(rec.dispose()) + else: + rec.dispose() + eng.clear() + + def after_test(self): + self._drop_testing_engines("function") + + def after_test_outside_fixtures(self, test): + # don't do aggressive checks for third party test suites + if not config.bootstrapped_as_sqlalchemy: + return + + if test.__class__.__leave_connections_for_teardown__: + return + + self.checkin_all() + + # on PostgreSQL, this will test for any "idle in transaction" + # connections. useful to identify tests with unusual patterns + # that can't be cleaned up correctly. + from . import provision + + with config.db.connect() as conn: + provision.prepare_for_drop_tables(conn.engine.url, conn) + + def stop_test_class_inside_fixtures(self): + self.checkin_all() + self._drop_testing_engines("function") + self._drop_testing_engines("class") + + def stop_test_class_outside_fixtures(self): + # ensure no refs to checked out connections at all. + + if pool.base._strong_ref_connection_records: + gc_collect() + + if pool.base._strong_ref_connection_records: + ln = len(pool.base._strong_ref_connection_records) + pool.base._strong_ref_connection_records.clear() + assert ( + False + ), "%d connection recs not cleared after test suite" % (ln) + + def final_cleanup(self): + self.checkin_all() + for scope in self.testing_engines: + self._drop_testing_engines(scope) + + def assert_all_closed(self): + for rec in self.proxy_refs: + if rec.is_valid: + assert False + + +testing_reaper = ConnectionKiller() + + +@decorator +def assert_conns_closed(fn, *args, **kw): + try: + fn(*args, **kw) + finally: + testing_reaper.assert_all_closed() + + +@decorator +def rollback_open_connections(fn, *args, **kw): + """Decorator that rolls back all open connections after fn execution.""" + + try: + fn(*args, **kw) + finally: + testing_reaper.rollback_all() + + +@decorator +def close_first(fn, *args, **kw): + """Decorator that closes all connections before fn execution.""" + + testing_reaper.checkin_all() + fn(*args, **kw) + + +@decorator +def close_open_connections(fn, *args, **kw): + """Decorator that closes all connections after fn execution.""" + try: + fn(*args, **kw) + finally: + testing_reaper.checkin_all() + + +def all_dialects(exclude=None): + import sqlalchemy.dialects as d + + for name in d.__all__: + # TEMPORARY + if exclude and name in exclude: + continue + mod = getattr(d, name, None) + if not mod: + mod = getattr( + __import__("sqlalchemy.dialects.%s" % name).dialects, name + ) + yield mod.dialect() + + +class ReconnectFixture: + def __init__(self, dbapi): + self.dbapi = dbapi + self.connections = [] + self.is_stopped = False + + def __getattr__(self, key): + return getattr(self.dbapi, key) + + def connect(self, *args, **kwargs): + + conn = self.dbapi.connect(*args, **kwargs) + if self.is_stopped: + self._safe(conn.close) + curs = conn.cursor() # should fail on Oracle etc. + # should fail for everything that didn't fail + # above, connection is closed + curs.execute("select 1") + assert False, "simulated connect failure didn't work" + else: + self.connections.append(conn) + return conn + + def _safe(self, fn): + try: + fn() + except Exception as e: + warnings.warn("ReconnectFixture couldn't close connection: %s" % e) + + def shutdown(self, stop=False): + # TODO: this doesn't cover all cases + # as nicely as we'd like, namely MySQLdb. + # would need to implement R. Brewer's + # proxy server idea to get better + # coverage. + self.is_stopped = stop + for c in list(self.connections): + self._safe(c.close) + self.connections = [] + + def restart(self): + self.is_stopped = False + + +def reconnecting_engine(url=None, options=None): + url = url or config.db.url + dbapi = config.db.dialect.dbapi + if not options: + options = {} + options["module"] = ReconnectFixture(dbapi) + engine = testing_engine(url, options) + _dispose = engine.dispose + + def dispose(): + engine.dialect.dbapi.shutdown() + engine.dialect.dbapi.is_stopped = False + _dispose() + + engine.test_shutdown = engine.dialect.dbapi.shutdown + engine.test_restart = engine.dialect.dbapi.restart + engine.dispose = dispose + return engine + + +def testing_engine( + url: Optional[URL] = None, + options: Optional[Dict[str, Any]] = None, + asyncio: Literal[False] = False, + transfer_staticpool: bool = False, +) -> Engine: + ... + + +def testing_engine( + url: Optional[URL] = None, + options: Optional[Dict[str, Any]] = None, + asyncio: Literal[True] = True, + transfer_staticpool: bool = False, +) -> AsyncEngine: + ... + + +def testing_engine( + url=None, + options=None, + asyncio=False, + transfer_staticpool=False, + share_pool=False, + _sqlite_savepoint=False, +): + if asyncio: + assert not _sqlite_savepoint + from sqlalchemy.ext.asyncio import ( + create_async_engine as create_engine, + ) + else: + from sqlalchemy import create_engine + from sqlalchemy.engine.url import make_url + + if not options: + use_reaper = True + scope = "function" + sqlite_savepoint = False + else: + use_reaper = options.pop("use_reaper", True) + scope = options.pop("scope", "function") + sqlite_savepoint = options.pop("sqlite_savepoint", False) + + url = url or config.db.url + + url = make_url(url) + if options is None: + if config.db is None or url.drivername == config.db.url.drivername: + options = config.db_opts + else: + options = {} + elif config.db is not None and url.drivername == config.db.url.drivername: + default_opt = config.db_opts.copy() + default_opt.update(options) + + engine = create_engine(url, **options) + + if sqlite_savepoint and engine.name == "sqlite": + # apply SQLite savepoint workaround + @event.listens_for(engine, "connect") + def do_connect(dbapi_connection, connection_record): + dbapi_connection.isolation_level = None + + @event.listens_for(engine, "begin") + def do_begin(conn): + conn.exec_driver_sql("BEGIN") + + if transfer_staticpool: + from sqlalchemy.pool import StaticPool + + if config.db is not None and isinstance(config.db.pool, StaticPool): + use_reaper = False + engine.pool._transfer_from(config.db.pool) + elif share_pool: + engine.pool = config.db.pool + + if scope == "global": + if asyncio: + engine.sync_engine._has_events = True + else: + engine._has_events = ( + True # enable event blocks, helps with profiling + ) + + if isinstance(engine.pool, pool.QueuePool): + engine.pool._timeout = 0 + engine.pool._max_overflow = 0 + if use_reaper: + testing_reaper.add_engine(engine, scope) + + return engine + + +def mock_engine(dialect_name=None): + """Provides a mocking engine based on the current testing.db. + + This is normally used to test DDL generation flow as emitted + by an Engine. + + It should not be used in other cases, as assert_compile() and + assert_sql_execution() are much better choices with fewer + moving parts. + + """ + + from sqlalchemy import create_mock_engine + + if not dialect_name: + dialect_name = config.db.name + + buffer = [] + + def executor(sql, *a, **kw): + buffer.append(sql) + + def assert_sql(stmts): + recv = [re.sub(r"[\n\t]", "", str(s)) for s in buffer] + assert recv == stmts, recv + + def print_sql(): + d = engine.dialect + return "\n".join(str(s.compile(dialect=d)) for s in engine.mock) + + engine = create_mock_engine(dialect_name + "://", executor) + assert not hasattr(engine, "mock") + engine.mock = buffer + engine.assert_sql = assert_sql + engine.print_sql = print_sql + return engine + + +class DBAPIProxyCursor: + """Proxy a DBAPI cursor. + + Tests can provide subclasses of this to intercept + DBAPI-level cursor operations. + + """ + + def __init__(self, engine, conn, *args, **kwargs): + self.engine = engine + self.connection = conn + self.cursor = conn.cursor(*args, **kwargs) + + def execute(self, stmt, parameters=None, **kw): + if parameters: + return self.cursor.execute(stmt, parameters, **kw) + else: + return self.cursor.execute(stmt, **kw) + + def executemany(self, stmt, params, **kw): + return self.cursor.executemany(stmt, params, **kw) + + def __iter__(self): + return iter(self.cursor) + + def __getattr__(self, key): + return getattr(self.cursor, key) + + +class DBAPIProxyConnection: + """Proxy a DBAPI connection. + + Tests can provide subclasses of this to intercept + DBAPI-level connection operations. + + """ + + def __init__(self, engine, conn, cursor_cls): + self.conn = conn + self.engine = engine + self.cursor_cls = cursor_cls + + def cursor(self, *args, **kwargs): + return self.cursor_cls(self.engine, self.conn, *args, **kwargs) + + def close(self): + self.conn.close() + + def __getattr__(self, key): + return getattr(self.conn, key) |