diff options
Diffstat (limited to 'libs/alembic/operations/schemaobj.py')
-rw-r--r-- | libs/alembic/operations/schemaobj.py | 284 |
1 files changed, 284 insertions, 0 deletions
diff --git a/libs/alembic/operations/schemaobj.py b/libs/alembic/operations/schemaobj.py new file mode 100644 index 000000000..0568471a7 --- /dev/null +++ b/libs/alembic/operations/schemaobj.py @@ -0,0 +1,284 @@ +from __future__ import annotations + +from typing import Any +from typing import Dict +from typing import List +from typing import Optional +from typing import Sequence +from typing import Tuple +from typing import TYPE_CHECKING +from typing import Union + +from sqlalchemy import schema as sa_schema +from sqlalchemy.sql.schema import Column +from sqlalchemy.sql.schema import Constraint +from sqlalchemy.sql.schema import Index +from sqlalchemy.types import Integer +from sqlalchemy.types import NULLTYPE + +from .. import util +from ..util import sqla_compat + +if TYPE_CHECKING: + from sqlalchemy.sql.elements import ColumnElement + from sqlalchemy.sql.elements import TextClause + from sqlalchemy.sql.schema import CheckConstraint + from sqlalchemy.sql.schema import ForeignKey + from sqlalchemy.sql.schema import ForeignKeyConstraint + from sqlalchemy.sql.schema import MetaData + from sqlalchemy.sql.schema import PrimaryKeyConstraint + from sqlalchemy.sql.schema import Table + from sqlalchemy.sql.schema import UniqueConstraint + from sqlalchemy.sql.type_api import TypeEngine + + from ..runtime.migration import MigrationContext + + +class SchemaObjects: + def __init__( + self, migration_context: Optional[MigrationContext] = None + ) -> None: + self.migration_context = migration_context + + def primary_key_constraint( + self, + name: Optional[sqla_compat._ConstraintNameDefined], + table_name: str, + cols: Sequence[str], + schema: Optional[str] = None, + **dialect_kw, + ) -> PrimaryKeyConstraint: + m = self.metadata() + columns = [sa_schema.Column(n, NULLTYPE) for n in cols] + t = sa_schema.Table(table_name, m, *columns, schema=schema) + # SQLAlchemy primary key constraint name arg is wrongly typed on + # the SQLAlchemy side through 2.0.5 at least + p = sa_schema.PrimaryKeyConstraint( + *[t.c[n] for n in cols], name=name, **dialect_kw # type: ignore + ) + return p + + def foreign_key_constraint( + self, + name: Optional[sqla_compat._ConstraintNameDefined], + source: str, + referent: str, + local_cols: List[str], + remote_cols: List[str], + onupdate: Optional[str] = None, + ondelete: Optional[str] = None, + deferrable: Optional[bool] = None, + source_schema: Optional[str] = None, + referent_schema: Optional[str] = None, + initially: Optional[str] = None, + match: Optional[str] = None, + **dialect_kw, + ) -> ForeignKeyConstraint: + m = self.metadata() + if source == referent and source_schema == referent_schema: + t1_cols = local_cols + remote_cols + else: + t1_cols = local_cols + sa_schema.Table( + referent, + m, + *[sa_schema.Column(n, NULLTYPE) for n in remote_cols], + schema=referent_schema, + ) + + t1 = sa_schema.Table( + source, + m, + *[sa_schema.Column(n, NULLTYPE) for n in t1_cols], + schema=source_schema, + ) + + tname = ( + "%s.%s" % (referent_schema, referent) + if referent_schema + else referent + ) + + dialect_kw["match"] = match + + f = sa_schema.ForeignKeyConstraint( + local_cols, + ["%s.%s" % (tname, n) for n in remote_cols], + name=name, + onupdate=onupdate, + ondelete=ondelete, + deferrable=deferrable, + initially=initially, + **dialect_kw, + ) + t1.append_constraint(f) + + return f + + def unique_constraint( + self, + name: Optional[sqla_compat._ConstraintNameDefined], + source: str, + local_cols: Sequence[str], + schema: Optional[str] = None, + **kw, + ) -> UniqueConstraint: + t = sa_schema.Table( + source, + self.metadata(), + *[sa_schema.Column(n, NULLTYPE) for n in local_cols], + schema=schema, + ) + kw["name"] = name + uq = sa_schema.UniqueConstraint(*[t.c[n] for n in local_cols], **kw) + # TODO: need event tests to ensure the event + # is fired off here + t.append_constraint(uq) + return uq + + def check_constraint( + self, + name: Optional[sqla_compat._ConstraintNameDefined], + source: str, + condition: Union[str, TextClause, ColumnElement[Any]], + schema: Optional[str] = None, + **kw, + ) -> Union[CheckConstraint]: + t = sa_schema.Table( + source, + self.metadata(), + sa_schema.Column("x", Integer), + schema=schema, + ) + ck = sa_schema.CheckConstraint(condition, name=name, **kw) + t.append_constraint(ck) + return ck + + def generic_constraint( + self, + name: Optional[sqla_compat._ConstraintNameDefined], + table_name: str, + type_: Optional[str], + schema: Optional[str] = None, + **kw, + ) -> Any: + t = self.table(table_name, schema=schema) + types: Dict[Optional[str], Any] = { + "foreignkey": lambda name: sa_schema.ForeignKeyConstraint( + [], [], name=name + ), + "primary": sa_schema.PrimaryKeyConstraint, + "unique": sa_schema.UniqueConstraint, + "check": lambda name: sa_schema.CheckConstraint("", name=name), + None: sa_schema.Constraint, + } + try: + const = types[type_] + except KeyError as ke: + raise TypeError( + "'type' can be one of %s" + % ", ".join(sorted(repr(x) for x in types)) + ) from ke + else: + const = const(name=name) + t.append_constraint(const) + return const + + def metadata(self) -> MetaData: + kw = {} + if ( + self.migration_context is not None + and "target_metadata" in self.migration_context.opts + ): + mt = self.migration_context.opts["target_metadata"] + if hasattr(mt, "naming_convention"): + kw["naming_convention"] = mt.naming_convention + return sa_schema.MetaData(**kw) + + def table(self, name: str, *columns, **kw) -> Table: + m = self.metadata() + + cols = [ + sqla_compat._copy(c) if c.table is not None else c + for c in columns + if isinstance(c, Column) + ] + # these flags have already added their UniqueConstraint / + # Index objects to the table, so flip them off here. + # SQLAlchemy tometadata() avoids this instead by preserving the + # flags and skipping the constraints that have _type_bound on them, + # but for a migration we'd rather list out the constraints + # explicitly. + _constraints_included = kw.pop("_constraints_included", False) + if _constraints_included: + for c in cols: + c.unique = c.index = False + + t = sa_schema.Table(name, m, *cols, **kw) + + constraints = [ + sqla_compat._copy(elem, target_table=t) + if getattr(elem, "parent", None) is not t + and getattr(elem, "parent", None) is not None + else elem + for elem in columns + if isinstance(elem, (Constraint, Index)) + ] + + for const in constraints: + t.append_constraint(const) + + for f in t.foreign_keys: + self._ensure_table_for_fk(m, f) + return t + + def column(self, name: str, type_: TypeEngine, **kw) -> Column: + return sa_schema.Column(name, type_, **kw) + + def index( + self, + name: Optional[str], + tablename: Optional[str], + columns: Sequence[Union[str, TextClause, ColumnElement[Any]]], + schema: Optional[str] = None, + **kw, + ) -> Index: + t = sa_schema.Table( + tablename or "no_table", + self.metadata(), + schema=schema, + ) + kw["_table"] = t + idx = sa_schema.Index( + name, + *[util.sqla_compat._textual_index_column(t, n) for n in columns], + **kw, + ) + return idx + + def _parse_table_key(self, table_key: str) -> Tuple[Optional[str], str]: + if "." in table_key: + tokens = table_key.split(".") + sname: Optional[str] = ".".join(tokens[0:-1]) + tname = tokens[-1] + else: + tname = table_key + sname = None + return (sname, tname) + + def _ensure_table_for_fk(self, metadata: MetaData, fk: ForeignKey) -> None: + """create a placeholder Table object for the referent of a + ForeignKey. + + """ + if isinstance(fk._colspec, str): # type:ignore[attr-defined] + table_key, cname = fk._colspec.rsplit( # type:ignore[attr-defined] + ".", 1 + ) + sname, tname = self._parse_table_key(table_key) + if table_key not in metadata.tables: + rel_t = sa_schema.Table(tname, metadata, schema=sname) + else: + rel_t = metadata.tables[table_key] + if cname not in rel_t.c: + rel_t.append_column(sa_schema.Column(cname, NULLTYPE)) |