diff options
Diffstat (limited to 'libs/alembic/util/sqla_compat.py')
-rw-r--r-- | libs/alembic/util/sqla_compat.py | 607 |
1 files changed, 607 insertions, 0 deletions
diff --git a/libs/alembic/util/sqla_compat.py b/libs/alembic/util/sqla_compat.py new file mode 100644 index 000000000..cab99494b --- /dev/null +++ b/libs/alembic/util/sqla_compat.py @@ -0,0 +1,607 @@ +from __future__ import annotations + +import contextlib +import re +from typing import Any +from typing import Iterable +from typing import Iterator +from typing import Mapping +from typing import Optional +from typing import TYPE_CHECKING +from typing import TypeVar +from typing import Union + +from sqlalchemy import __version__ +from sqlalchemy import inspect +from sqlalchemy import schema +from sqlalchemy import sql +from sqlalchemy import types as sqltypes +from sqlalchemy.engine import url +from sqlalchemy.ext.compiler import compiles +from sqlalchemy.schema import CheckConstraint +from sqlalchemy.schema import Column +from sqlalchemy.schema import ForeignKeyConstraint +from sqlalchemy.sql import visitors +from sqlalchemy.sql.elements import BindParameter +from sqlalchemy.sql.elements import ColumnClause +from sqlalchemy.sql.elements import quoted_name +from sqlalchemy.sql.elements import TextClause +from sqlalchemy.sql.elements import UnaryExpression +from sqlalchemy.sql.visitors import traverse +from typing_extensions import TypeGuard + +if TYPE_CHECKING: + from sqlalchemy import Index + from sqlalchemy import Table + from sqlalchemy.engine import Connection + from sqlalchemy.engine import Dialect + from sqlalchemy.engine import Transaction + from sqlalchemy.engine.reflection import Inspector + from sqlalchemy.sql.base import ColumnCollection + from sqlalchemy.sql.compiler import SQLCompiler + from sqlalchemy.sql.dml import Insert + from sqlalchemy.sql.elements import ColumnElement + from sqlalchemy.sql.schema import Constraint + from sqlalchemy.sql.schema import SchemaItem + from sqlalchemy.sql.selectable import Select + from sqlalchemy.sql.selectable import TableClause + +_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"]) + + +def _safe_int(value: str) -> Union[int, str]: + try: + return int(value) + except: + return value + + +_vers = tuple( + [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)] +) +sqla_13 = _vers >= (1, 3) +sqla_14 = _vers >= (1, 4) +sqla_14_26 = _vers >= (1, 4, 26) +sqla_2 = _vers >= (2,) +sqlalchemy_version = __version__ + +try: + from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME +except ImportError: + from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME # type: ignore # noqa: E501 + + +if sqla_14: + # when future engine merges, this can be again based on version string + from sqlalchemy.engine import Connection as legacy_connection + + sqla_1x = not hasattr(legacy_connection, "commit") +else: + sqla_1x = True + +try: + from sqlalchemy import Computed # noqa +except ImportError: + Computed = type(None) # type: ignore + has_computed = False + has_computed_reflection = False +else: + has_computed = True + has_computed_reflection = _vers >= (1, 3, 16) + +try: + from sqlalchemy import Identity # noqa +except ImportError: + Identity = type(None) # type: ignore + has_identity = False +else: + # attributes common to Indentity and Sequence + _identity_options_attrs = ( + "start", + "increment", + "minvalue", + "maxvalue", + "nominvalue", + "nomaxvalue", + "cycle", + "cache", + "order", + ) + # attributes of Indentity + _identity_attrs = _identity_options_attrs + ("on_null",) + has_identity = True + +if sqla_2: + from sqlalchemy.sql.base import _NoneName +else: + from sqlalchemy.util import symbol as _NoneName # type: ignore[assignment] + + +_ConstraintName = Union[None, str, _NoneName] + +_ConstraintNameDefined = Union[str, _NoneName] + + +def constraint_name_defined( + name: _ConstraintName, +) -> TypeGuard[_ConstraintNameDefined]: + return name is _NONE_NAME or isinstance(name, (str, _NoneName)) + + +def constraint_name_string( + name: _ConstraintName, +) -> TypeGuard[str]: + return isinstance(name, str) + + +def constraint_name_or_none( + name: _ConstraintName, +) -> Optional[str]: + return name if constraint_name_string(name) else None + + +AUTOINCREMENT_DEFAULT = "auto" + + +def _ensure_scope_for_ddl( + connection: Optional[Connection], +) -> Iterator[None]: + try: + in_transaction = connection.in_transaction # type: ignore[union-attr] + except AttributeError: + # catch for MockConnection, None + in_transaction = None + pass + + # yield outside the catch + if in_transaction is None: + yield + else: + if not in_transaction(): + assert connection is not None + with connection.begin(): + yield + else: + yield + + +def url_render_as_string(url, hide_password=True): + if sqla_14: + return url.render_as_string(hide_password=hide_password) + else: + return url.__to_string__(hide_password=hide_password) + + +def _safe_begin_connection_transaction( + connection: Connection, +) -> Transaction: + transaction = _get_connection_transaction(connection) + if transaction: + return transaction + else: + return connection.begin() + + +def _safe_commit_connection_transaction( + connection: Connection, +) -> None: + transaction = _get_connection_transaction(connection) + if transaction: + transaction.commit() + + +def _safe_rollback_connection_transaction( + connection: Connection, +) -> None: + transaction = _get_connection_transaction(connection) + if transaction: + transaction.rollback() + + +def _get_connection_in_transaction(connection: Optional[Connection]) -> bool: + try: + in_transaction = connection.in_transaction # type: ignore + except AttributeError: + # catch for MockConnection + return False + else: + return in_transaction() + + +def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]: + return idx.expressions # type: ignore + + +def _copy(schema_item: _CE, **kw) -> _CE: + if hasattr(schema_item, "_copy"): + return schema_item._copy(**kw) # type: ignore[union-attr] + else: + return schema_item.copy(**kw) # type: ignore[union-attr] + + +def _get_connection_transaction( + connection: Connection, +) -> Optional[Transaction]: + if sqla_14: + return connection.get_transaction() + else: + r = connection._root # type: ignore[attr-defined] + return r._Connection__transaction + + +def _create_url(*arg, **kw) -> url.URL: + if hasattr(url.URL, "create"): + return url.URL.create(*arg, **kw) + else: + return url.URL(*arg, **kw) + + +def _connectable_has_table( + connectable: Connection, tablename: str, schemaname: Union[str, None] +) -> bool: + if sqla_14: + return inspect(connectable).has_table(tablename, schemaname) + else: + return connectable.dialect.has_table( + connectable, tablename, schemaname + ) + + +def _exec_on_inspector(inspector, statement, **params): + if sqla_14: + with inspector._operation_context() as conn: + return conn.execute(statement, params) + else: + return inspector.bind.execute(statement, params) + + +def _nullability_might_be_unset(metadata_column): + if not sqla_14: + return metadata_column.nullable + else: + from sqlalchemy.sql import schema + + return ( + metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED + ) + + +def _server_default_is_computed(*server_default) -> bool: + if not has_computed: + return False + else: + return any(isinstance(sd, Computed) for sd in server_default) + + +def _server_default_is_identity(*server_default) -> bool: + if not sqla_14: + return False + else: + return any(isinstance(sd, Identity) for sd in server_default) + + +def _table_for_constraint(constraint: Constraint) -> Table: + if isinstance(constraint, ForeignKeyConstraint): + table = constraint.parent + assert table is not None + return table # type: ignore[return-value] + else: + return constraint.table + + +def _columns_for_constraint(constraint): + if isinstance(constraint, ForeignKeyConstraint): + return [fk.parent for fk in constraint.elements] + elif isinstance(constraint, CheckConstraint): + return _find_columns(constraint.sqltext) + else: + return list(constraint.columns) + + +def _reflect_table( + inspector: Inspector, table: Table, include_cols: None +) -> None: + if sqla_14: + return inspector.reflect_table(table, None) + else: + return inspector.reflecttable( # type: ignore[attr-defined] + table, None + ) + + +def _resolve_for_variant(type_, dialect): + if _type_has_variants(type_): + base_type, mapping = _get_variant_mapping(type_) + return mapping.get(dialect.name, base_type) + else: + return type_ + + +if hasattr(sqltypes.TypeEngine, "_variant_mapping"): + + def _type_has_variants(type_): + return bool(type_._variant_mapping) + + def _get_variant_mapping(type_): + return type_, type_._variant_mapping + +else: + + def _type_has_variants(type_): + return type(type_) is sqltypes.Variant + + def _get_variant_mapping(type_): + return type_.impl, type_.mapping + + +def _fk_spec(constraint): + source_columns = [ + constraint.columns[key].name for key in constraint.column_keys + ] + + source_table = constraint.parent.name + source_schema = constraint.parent.schema + target_schema = constraint.elements[0].column.table.schema + target_table = constraint.elements[0].column.table.name + target_columns = [element.column.name for element in constraint.elements] + ondelete = constraint.ondelete + onupdate = constraint.onupdate + deferrable = constraint.deferrable + initially = constraint.initially + return ( + source_schema, + source_table, + source_columns, + target_schema, + target_table, + target_columns, + onupdate, + ondelete, + deferrable, + initially, + ) + + +def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool: + spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined] + tokens = spec.split(".") + tokens.pop(-1) # colname + tablekey = ".".join(tokens) + assert constraint.parent is not None + return tablekey == constraint.parent.key + + +def _is_type_bound(constraint: Constraint) -> bool: + # this deals with SQLAlchemy #3260, don't copy CHECK constraints + # that will be generated by the type. + # new feature added for #3260 + return constraint._type_bound # type: ignore[attr-defined] + + +def _find_columns(clause): + """locate Column objects within the given expression.""" + + cols = set() + traverse(clause, {}, {"column": cols.add}) + return cols + + +def _remove_column_from_collection( + collection: ColumnCollection, column: Union[Column, ColumnClause] +) -> None: + """remove a column from a ColumnCollection.""" + + # workaround for older SQLAlchemy, remove the + # same object that's present + assert column.key is not None + to_remove = collection[column.key] + + # SQLAlchemy 2.0 will use more ReadOnlyColumnCollection + # (renamed from ImmutableColumnCollection) + if hasattr(collection, "_immutable") or hasattr(collection, "_readonly"): + collection._parent.remove(to_remove) + else: + collection.remove(to_remove) + + +def _textual_index_column( + table: Table, text_: Union[str, TextClause, ColumnElement] +) -> Union[ColumnElement, Column]: + """a workaround for the Index construct's severe lack of flexibility""" + if isinstance(text_, str): + c = Column(text_, sqltypes.NULLTYPE) + table.append_column(c) + return c + elif isinstance(text_, TextClause): + return _textual_index_element(table, text_) + elif isinstance(text_, _textual_index_element): + return _textual_index_column(table, text_.text) + elif isinstance(text_, sql.ColumnElement): + return _copy_expression(text_, table) + else: + raise ValueError("String or text() construct expected") + + +def _copy_expression(expression: _CE, target_table: Table) -> _CE: + def replace(col): + if ( + isinstance(col, Column) + and col.table is not None + and col.table is not target_table + ): + if col.name in target_table.c: + return target_table.c[col.name] + else: + c = _copy(col) + target_table.append_column(c) + return c + else: + return None + + return visitors.replacement_traverse( # type: ignore[call-overload] + expression, {}, replace + ) + + +class _textual_index_element(sql.ColumnElement): + """Wrap around a sqlalchemy text() construct in such a way that + we appear like a column-oriented SQL expression to an Index + construct. + + The issue here is that currently the Postgresql dialect, the biggest + recipient of functional indexes, keys all the index expressions to + the corresponding column expressions when rendering CREATE INDEX, + so the Index we create here needs to have a .columns collection that + is the same length as the .expressions collection. Ultimately + SQLAlchemy should support text() expressions in indexes. + + See SQLAlchemy issue 3174. + + """ + + __visit_name__ = "_textual_idx_element" + + def __init__(self, table: Table, text: TextClause) -> None: + self.table = table + self.text = text + self.key = text.text + self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE) + table.append_column(self.fake_column) + + def get_children(self): + return [self.fake_column] + + +@compiles(_textual_index_element) +def _render_textual_index_column( + element: _textual_index_element, compiler: SQLCompiler, **kw +) -> str: + return compiler.process(element.text, **kw) + + +class _literal_bindparam(BindParameter): + pass + + +@compiles(_literal_bindparam) +def _render_literal_bindparam( + element: _literal_bindparam, compiler: SQLCompiler, **kw +) -> str: + return compiler.render_literal_bindparam(element, **kw) + + +def _get_index_expressions(idx): + return list(idx.expressions) + + +def _get_index_column_names(idx): + return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)] + + +def _column_kwargs(col: Column) -> Mapping: + if sqla_13: + return col.kwargs + else: + return {} + + +def _get_constraint_final_name( + constraint: Union[Index, Constraint], dialect: Optional[Dialect] +) -> Optional[str]: + if constraint.name is None: + return None + assert dialect is not None + if sqla_14: + # for SQLAlchemy 1.4 we would like to have the option to expand + # the use of "deferred" names for constraints as well as to have + # some flexibility with "None" name and similar; make use of new + # SQLAlchemy API to return what would be the final compiled form of + # the name for this dialect. + return dialect.identifier_preparer.format_constraint( + constraint, _alembic_quote=False + ) + else: + + # prior to SQLAlchemy 1.4, work around quoting logic to get at the + # final compiled name without quotes. + if hasattr(constraint.name, "quote"): + # might be quoted_name, might be truncated_name, keep it the + # same + quoted_name_cls: type = type(constraint.name) + else: + quoted_name_cls = quoted_name + + new_name = quoted_name_cls(str(constraint.name), quote=False) + constraint = constraint.__class__(name=new_name) + + if isinstance(constraint, schema.Index): + # name should not be quoted. + d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type] + return d._prepared_index_name( # type: ignore[attr-defined] + constraint + ) + else: + # name should not be quoted. + return dialect.identifier_preparer.format_constraint(constraint) + + +def _constraint_is_named( + constraint: Union[Constraint, Index], dialect: Optional[Dialect] +) -> bool: + if sqla_14: + if constraint.name is None: + return False + assert dialect is not None + name = dialect.identifier_preparer.format_constraint( + constraint, _alembic_quote=False + ) + return name is not None + else: + return constraint.name is not None + + +def _is_mariadb(mysql_dialect: Dialect) -> bool: + if sqla_14: + return mysql_dialect.is_mariadb # type: ignore[attr-defined] + else: + return bool( + mysql_dialect.server_version_info + and mysql_dialect._is_mariadb # type: ignore[attr-defined] + ) + + +def _mariadb_normalized_version_info(mysql_dialect): + return mysql_dialect._mariadb_normalized_version_info + + +def _insert_inline(table: Union[TableClause, Table]) -> Insert: + if sqla_14: + return table.insert().inline() + else: + return table.insert(inline=True) # type: ignore[call-arg] + + +if sqla_14: + from sqlalchemy import create_mock_engine + from sqlalchemy import select as _select +else: + from sqlalchemy import create_engine + + def create_mock_engine(url, executor, **kw): # type: ignore[misc] + return create_engine( + "postgresql://", strategy="mock", executor=executor + ) + + def _select(*columns, **kw) -> Select: # type: ignore[no-redef] + return sql.select(list(columns), **kw) # type: ignore[call-overload] + + +def is_expression_index(index: Index) -> bool: + expr: Any + for expr in index.expressions: + while isinstance(expr, UnaryExpression): + expr = expr.element + if not isinstance(expr, ColumnClause) or expr.is_literal: + return True + return False |