summaryrefslogtreecommitdiffhomepage
path: root/libs/alembic/util/sqla_compat.py
diff options
context:
space:
mode:
Diffstat (limited to 'libs/alembic/util/sqla_compat.py')
-rw-r--r--libs/alembic/util/sqla_compat.py607
1 files changed, 607 insertions, 0 deletions
diff --git a/libs/alembic/util/sqla_compat.py b/libs/alembic/util/sqla_compat.py
new file mode 100644
index 000000000..cab99494b
--- /dev/null
+++ b/libs/alembic/util/sqla_compat.py
@@ -0,0 +1,607 @@
+from __future__ import annotations
+
+import contextlib
+import re
+from typing import Any
+from typing import Iterable
+from typing import Iterator
+from typing import Mapping
+from typing import Optional
+from typing import TYPE_CHECKING
+from typing import TypeVar
+from typing import Union
+
+from sqlalchemy import __version__
+from sqlalchemy import inspect
+from sqlalchemy import schema
+from sqlalchemy import sql
+from sqlalchemy import types as sqltypes
+from sqlalchemy.engine import url
+from sqlalchemy.ext.compiler import compiles
+from sqlalchemy.schema import CheckConstraint
+from sqlalchemy.schema import Column
+from sqlalchemy.schema import ForeignKeyConstraint
+from sqlalchemy.sql import visitors
+from sqlalchemy.sql.elements import BindParameter
+from sqlalchemy.sql.elements import ColumnClause
+from sqlalchemy.sql.elements import quoted_name
+from sqlalchemy.sql.elements import TextClause
+from sqlalchemy.sql.elements import UnaryExpression
+from sqlalchemy.sql.visitors import traverse
+from typing_extensions import TypeGuard
+
+if TYPE_CHECKING:
+ from sqlalchemy import Index
+ from sqlalchemy import Table
+ from sqlalchemy.engine import Connection
+ from sqlalchemy.engine import Dialect
+ from sqlalchemy.engine import Transaction
+ from sqlalchemy.engine.reflection import Inspector
+ from sqlalchemy.sql.base import ColumnCollection
+ from sqlalchemy.sql.compiler import SQLCompiler
+ from sqlalchemy.sql.dml import Insert
+ from sqlalchemy.sql.elements import ColumnElement
+ from sqlalchemy.sql.schema import Constraint
+ from sqlalchemy.sql.schema import SchemaItem
+ from sqlalchemy.sql.selectable import Select
+ from sqlalchemy.sql.selectable import TableClause
+
+_CE = TypeVar("_CE", bound=Union["ColumnElement", "SchemaItem"])
+
+
+def _safe_int(value: str) -> Union[int, str]:
+ try:
+ return int(value)
+ except:
+ return value
+
+
+_vers = tuple(
+ [_safe_int(x) for x in re.findall(r"(\d+|[abc]\d)", __version__)]
+)
+sqla_13 = _vers >= (1, 3)
+sqla_14 = _vers >= (1, 4)
+sqla_14_26 = _vers >= (1, 4, 26)
+sqla_2 = _vers >= (2,)
+sqlalchemy_version = __version__
+
+try:
+ from sqlalchemy.sql.naming import _NONE_NAME as _NONE_NAME
+except ImportError:
+ from sqlalchemy.sql.elements import _NONE_NAME as _NONE_NAME # type: ignore # noqa: E501
+
+
+if sqla_14:
+ # when future engine merges, this can be again based on version string
+ from sqlalchemy.engine import Connection as legacy_connection
+
+ sqla_1x = not hasattr(legacy_connection, "commit")
+else:
+ sqla_1x = True
+
+try:
+ from sqlalchemy import Computed # noqa
+except ImportError:
+ Computed = type(None) # type: ignore
+ has_computed = False
+ has_computed_reflection = False
+else:
+ has_computed = True
+ has_computed_reflection = _vers >= (1, 3, 16)
+
+try:
+ from sqlalchemy import Identity # noqa
+except ImportError:
+ Identity = type(None) # type: ignore
+ has_identity = False
+else:
+ # attributes common to Indentity and Sequence
+ _identity_options_attrs = (
+ "start",
+ "increment",
+ "minvalue",
+ "maxvalue",
+ "nominvalue",
+ "nomaxvalue",
+ "cycle",
+ "cache",
+ "order",
+ )
+ # attributes of Indentity
+ _identity_attrs = _identity_options_attrs + ("on_null",)
+ has_identity = True
+
+if sqla_2:
+ from sqlalchemy.sql.base import _NoneName
+else:
+ from sqlalchemy.util import symbol as _NoneName # type: ignore[assignment]
+
+
+_ConstraintName = Union[None, str, _NoneName]
+
+_ConstraintNameDefined = Union[str, _NoneName]
+
+
+def constraint_name_defined(
+ name: _ConstraintName,
+) -> TypeGuard[_ConstraintNameDefined]:
+ return name is _NONE_NAME or isinstance(name, (str, _NoneName))
+
+
+def constraint_name_string(
+ name: _ConstraintName,
+) -> TypeGuard[str]:
+ return isinstance(name, str)
+
+
+def constraint_name_or_none(
+ name: _ConstraintName,
+) -> Optional[str]:
+ return name if constraint_name_string(name) else None
+
+
+AUTOINCREMENT_DEFAULT = "auto"
+
+
+def _ensure_scope_for_ddl(
+ connection: Optional[Connection],
+) -> Iterator[None]:
+ try:
+ in_transaction = connection.in_transaction # type: ignore[union-attr]
+ except AttributeError:
+ # catch for MockConnection, None
+ in_transaction = None
+ pass
+
+ # yield outside the catch
+ if in_transaction is None:
+ yield
+ else:
+ if not in_transaction():
+ assert connection is not None
+ with connection.begin():
+ yield
+ else:
+ yield
+
+
+def url_render_as_string(url, hide_password=True):
+ if sqla_14:
+ return url.render_as_string(hide_password=hide_password)
+ else:
+ return url.__to_string__(hide_password=hide_password)
+
+
+def _safe_begin_connection_transaction(
+ connection: Connection,
+) -> Transaction:
+ transaction = _get_connection_transaction(connection)
+ if transaction:
+ return transaction
+ else:
+ return connection.begin()
+
+
+def _safe_commit_connection_transaction(
+ connection: Connection,
+) -> None:
+ transaction = _get_connection_transaction(connection)
+ if transaction:
+ transaction.commit()
+
+
+def _safe_rollback_connection_transaction(
+ connection: Connection,
+) -> None:
+ transaction = _get_connection_transaction(connection)
+ if transaction:
+ transaction.rollback()
+
+
+def _get_connection_in_transaction(connection: Optional[Connection]) -> bool:
+ try:
+ in_transaction = connection.in_transaction # type: ignore
+ except AttributeError:
+ # catch for MockConnection
+ return False
+ else:
+ return in_transaction()
+
+
+def _idx_table_bound_expressions(idx: Index) -> Iterable[ColumnElement[Any]]:
+ return idx.expressions # type: ignore
+
+
+def _copy(schema_item: _CE, **kw) -> _CE:
+ if hasattr(schema_item, "_copy"):
+ return schema_item._copy(**kw) # type: ignore[union-attr]
+ else:
+ return schema_item.copy(**kw) # type: ignore[union-attr]
+
+
+def _get_connection_transaction(
+ connection: Connection,
+) -> Optional[Transaction]:
+ if sqla_14:
+ return connection.get_transaction()
+ else:
+ r = connection._root # type: ignore[attr-defined]
+ return r._Connection__transaction
+
+
+def _create_url(*arg, **kw) -> url.URL:
+ if hasattr(url.URL, "create"):
+ return url.URL.create(*arg, **kw)
+ else:
+ return url.URL(*arg, **kw)
+
+
+def _connectable_has_table(
+ connectable: Connection, tablename: str, schemaname: Union[str, None]
+) -> bool:
+ if sqla_14:
+ return inspect(connectable).has_table(tablename, schemaname)
+ else:
+ return connectable.dialect.has_table(
+ connectable, tablename, schemaname
+ )
+
+
+def _exec_on_inspector(inspector, statement, **params):
+ if sqla_14:
+ with inspector._operation_context() as conn:
+ return conn.execute(statement, params)
+ else:
+ return inspector.bind.execute(statement, params)
+
+
+def _nullability_might_be_unset(metadata_column):
+ if not sqla_14:
+ return metadata_column.nullable
+ else:
+ from sqlalchemy.sql import schema
+
+ return (
+ metadata_column._user_defined_nullable is schema.NULL_UNSPECIFIED
+ )
+
+
+def _server_default_is_computed(*server_default) -> bool:
+ if not has_computed:
+ return False
+ else:
+ return any(isinstance(sd, Computed) for sd in server_default)
+
+
+def _server_default_is_identity(*server_default) -> bool:
+ if not sqla_14:
+ return False
+ else:
+ return any(isinstance(sd, Identity) for sd in server_default)
+
+
+def _table_for_constraint(constraint: Constraint) -> Table:
+ if isinstance(constraint, ForeignKeyConstraint):
+ table = constraint.parent
+ assert table is not None
+ return table # type: ignore[return-value]
+ else:
+ return constraint.table
+
+
+def _columns_for_constraint(constraint):
+ if isinstance(constraint, ForeignKeyConstraint):
+ return [fk.parent for fk in constraint.elements]
+ elif isinstance(constraint, CheckConstraint):
+ return _find_columns(constraint.sqltext)
+ else:
+ return list(constraint.columns)
+
+
+def _reflect_table(
+ inspector: Inspector, table: Table, include_cols: None
+) -> None:
+ if sqla_14:
+ return inspector.reflect_table(table, None)
+ else:
+ return inspector.reflecttable( # type: ignore[attr-defined]
+ table, None
+ )
+
+
+def _resolve_for_variant(type_, dialect):
+ if _type_has_variants(type_):
+ base_type, mapping = _get_variant_mapping(type_)
+ return mapping.get(dialect.name, base_type)
+ else:
+ return type_
+
+
+if hasattr(sqltypes.TypeEngine, "_variant_mapping"):
+
+ def _type_has_variants(type_):
+ return bool(type_._variant_mapping)
+
+ def _get_variant_mapping(type_):
+ return type_, type_._variant_mapping
+
+else:
+
+ def _type_has_variants(type_):
+ return type(type_) is sqltypes.Variant
+
+ def _get_variant_mapping(type_):
+ return type_.impl, type_.mapping
+
+
+def _fk_spec(constraint):
+ source_columns = [
+ constraint.columns[key].name for key in constraint.column_keys
+ ]
+
+ source_table = constraint.parent.name
+ source_schema = constraint.parent.schema
+ target_schema = constraint.elements[0].column.table.schema
+ target_table = constraint.elements[0].column.table.name
+ target_columns = [element.column.name for element in constraint.elements]
+ ondelete = constraint.ondelete
+ onupdate = constraint.onupdate
+ deferrable = constraint.deferrable
+ initially = constraint.initially
+ return (
+ source_schema,
+ source_table,
+ source_columns,
+ target_schema,
+ target_table,
+ target_columns,
+ onupdate,
+ ondelete,
+ deferrable,
+ initially,
+ )
+
+
+def _fk_is_self_referential(constraint: ForeignKeyConstraint) -> bool:
+ spec = constraint.elements[0]._get_colspec() # type: ignore[attr-defined]
+ tokens = spec.split(".")
+ tokens.pop(-1) # colname
+ tablekey = ".".join(tokens)
+ assert constraint.parent is not None
+ return tablekey == constraint.parent.key
+
+
+def _is_type_bound(constraint: Constraint) -> bool:
+ # this deals with SQLAlchemy #3260, don't copy CHECK constraints
+ # that will be generated by the type.
+ # new feature added for #3260
+ return constraint._type_bound # type: ignore[attr-defined]
+
+
+def _find_columns(clause):
+ """locate Column objects within the given expression."""
+
+ cols = set()
+ traverse(clause, {}, {"column": cols.add})
+ return cols
+
+
+def _remove_column_from_collection(
+ collection: ColumnCollection, column: Union[Column, ColumnClause]
+) -> None:
+ """remove a column from a ColumnCollection."""
+
+ # workaround for older SQLAlchemy, remove the
+ # same object that's present
+ assert column.key is not None
+ to_remove = collection[column.key]
+
+ # SQLAlchemy 2.0 will use more ReadOnlyColumnCollection
+ # (renamed from ImmutableColumnCollection)
+ if hasattr(collection, "_immutable") or hasattr(collection, "_readonly"):
+ collection._parent.remove(to_remove)
+ else:
+ collection.remove(to_remove)
+
+
+def _textual_index_column(
+ table: Table, text_: Union[str, TextClause, ColumnElement]
+) -> Union[ColumnElement, Column]:
+ """a workaround for the Index construct's severe lack of flexibility"""
+ if isinstance(text_, str):
+ c = Column(text_, sqltypes.NULLTYPE)
+ table.append_column(c)
+ return c
+ elif isinstance(text_, TextClause):
+ return _textual_index_element(table, text_)
+ elif isinstance(text_, _textual_index_element):
+ return _textual_index_column(table, text_.text)
+ elif isinstance(text_, sql.ColumnElement):
+ return _copy_expression(text_, table)
+ else:
+ raise ValueError("String or text() construct expected")
+
+
+def _copy_expression(expression: _CE, target_table: Table) -> _CE:
+ def replace(col):
+ if (
+ isinstance(col, Column)
+ and col.table is not None
+ and col.table is not target_table
+ ):
+ if col.name in target_table.c:
+ return target_table.c[col.name]
+ else:
+ c = _copy(col)
+ target_table.append_column(c)
+ return c
+ else:
+ return None
+
+ return visitors.replacement_traverse( # type: ignore[call-overload]
+ expression, {}, replace
+ )
+
+
+class _textual_index_element(sql.ColumnElement):
+ """Wrap around a sqlalchemy text() construct in such a way that
+ we appear like a column-oriented SQL expression to an Index
+ construct.
+
+ The issue here is that currently the Postgresql dialect, the biggest
+ recipient of functional indexes, keys all the index expressions to
+ the corresponding column expressions when rendering CREATE INDEX,
+ so the Index we create here needs to have a .columns collection that
+ is the same length as the .expressions collection. Ultimately
+ SQLAlchemy should support text() expressions in indexes.
+
+ See SQLAlchemy issue 3174.
+
+ """
+
+ __visit_name__ = "_textual_idx_element"
+
+ def __init__(self, table: Table, text: TextClause) -> None:
+ self.table = table
+ self.text = text
+ self.key = text.text
+ self.fake_column = schema.Column(self.text.text, sqltypes.NULLTYPE)
+ table.append_column(self.fake_column)
+
+ def get_children(self):
+ return [self.fake_column]
+
+
+@compiles(_textual_index_element)
+def _render_textual_index_column(
+ element: _textual_index_element, compiler: SQLCompiler, **kw
+) -> str:
+ return compiler.process(element.text, **kw)
+
+
+class _literal_bindparam(BindParameter):
+ pass
+
+
+@compiles(_literal_bindparam)
+def _render_literal_bindparam(
+ element: _literal_bindparam, compiler: SQLCompiler, **kw
+) -> str:
+ return compiler.render_literal_bindparam(element, **kw)
+
+
+def _get_index_expressions(idx):
+ return list(idx.expressions)
+
+
+def _get_index_column_names(idx):
+ return [getattr(exp, "name", None) for exp in _get_index_expressions(idx)]
+
+
+def _column_kwargs(col: Column) -> Mapping:
+ if sqla_13:
+ return col.kwargs
+ else:
+ return {}
+
+
+def _get_constraint_final_name(
+ constraint: Union[Index, Constraint], dialect: Optional[Dialect]
+) -> Optional[str]:
+ if constraint.name is None:
+ return None
+ assert dialect is not None
+ if sqla_14:
+ # for SQLAlchemy 1.4 we would like to have the option to expand
+ # the use of "deferred" names for constraints as well as to have
+ # some flexibility with "None" name and similar; make use of new
+ # SQLAlchemy API to return what would be the final compiled form of
+ # the name for this dialect.
+ return dialect.identifier_preparer.format_constraint(
+ constraint, _alembic_quote=False
+ )
+ else:
+
+ # prior to SQLAlchemy 1.4, work around quoting logic to get at the
+ # final compiled name without quotes.
+ if hasattr(constraint.name, "quote"):
+ # might be quoted_name, might be truncated_name, keep it the
+ # same
+ quoted_name_cls: type = type(constraint.name)
+ else:
+ quoted_name_cls = quoted_name
+
+ new_name = quoted_name_cls(str(constraint.name), quote=False)
+ constraint = constraint.__class__(name=new_name)
+
+ if isinstance(constraint, schema.Index):
+ # name should not be quoted.
+ d = dialect.ddl_compiler(dialect, None) # type: ignore[arg-type]
+ return d._prepared_index_name( # type: ignore[attr-defined]
+ constraint
+ )
+ else:
+ # name should not be quoted.
+ return dialect.identifier_preparer.format_constraint(constraint)
+
+
+def _constraint_is_named(
+ constraint: Union[Constraint, Index], dialect: Optional[Dialect]
+) -> bool:
+ if sqla_14:
+ if constraint.name is None:
+ return False
+ assert dialect is not None
+ name = dialect.identifier_preparer.format_constraint(
+ constraint, _alembic_quote=False
+ )
+ return name is not None
+ else:
+ return constraint.name is not None
+
+
+def _is_mariadb(mysql_dialect: Dialect) -> bool:
+ if sqla_14:
+ return mysql_dialect.is_mariadb # type: ignore[attr-defined]
+ else:
+ return bool(
+ mysql_dialect.server_version_info
+ and mysql_dialect._is_mariadb # type: ignore[attr-defined]
+ )
+
+
+def _mariadb_normalized_version_info(mysql_dialect):
+ return mysql_dialect._mariadb_normalized_version_info
+
+
+def _insert_inline(table: Union[TableClause, Table]) -> Insert:
+ if sqla_14:
+ return table.insert().inline()
+ else:
+ return table.insert(inline=True) # type: ignore[call-arg]
+
+
+if sqla_14:
+ from sqlalchemy import create_mock_engine
+ from sqlalchemy import select as _select
+else:
+ from sqlalchemy import create_engine
+
+ def create_mock_engine(url, executor, **kw): # type: ignore[misc]
+ return create_engine(
+ "postgresql://", strategy="mock", executor=executor
+ )
+
+ def _select(*columns, **kw) -> Select: # type: ignore[no-redef]
+ return sql.select(list(columns), **kw) # type: ignore[call-overload]
+
+
+def is_expression_index(index: Index) -> bool:
+ expr: Any
+ for expr in index.expressions:
+ while isinstance(expr, UnaryExpression):
+ expr = expr.element
+ if not isinstance(expr, ColumnClause) or expr.is_literal:
+ return True
+ return False