summaryrefslogtreecommitdiffhomepage
path: root/libs/alembic/ddl/oracle.py
blob: 9715c1e81a7e67ae741dbc02b75e99e41fc3c097 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
from __future__ import annotations

import re
from typing import Any
from typing import Optional
from typing import TYPE_CHECKING

from sqlalchemy.ext.compiler import compiles
from sqlalchemy.sql import sqltypes

from .base import AddColumn
from .base import alter_table
from .base import ColumnComment
from .base import ColumnDefault
from .base import ColumnName
from .base import ColumnNullable
from .base import ColumnType
from .base import format_column_name
from .base import format_server_default
from .base import format_table_name
from .base import format_type
from .base import IdentityColumnDefault
from .base import RenameTable
from .impl import DefaultImpl

if TYPE_CHECKING:
    from sqlalchemy.dialects.oracle.base import OracleDDLCompiler
    from sqlalchemy.engine.cursor import CursorResult
    from sqlalchemy.sql.schema import Column


class OracleImpl(DefaultImpl):
    """Alembic :class:`DefaultImpl` specialization for the Oracle dialect.

    DDL is treated as non-transactional, an empty ``command_terminator`` is
    used, and in offline (``as_sql``) mode every executed statement is
    followed by a batch separator line (``"/"`` unless overridden by the
    ``oracle_batch_separator`` context option).
    """

    __dialect__ = "oracle"
    transactional_ddl = False
    # Written after every statement in as_sql mode; see _exec().
    batch_separator = "/"
    command_terminator = ""
    # Additional groups of type names regarded as interchangeable for this
    # dialect, appended to the groups inherited from DefaultImpl.
    type_synonyms = DefaultImpl.type_synonyms + (
        {"VARCHAR", "VARCHAR2"},
        {"BIGINT", "INTEGER", "SMALLINT", "DECIMAL", "NUMERIC", "NUMBER"},
        {"DOUBLE", "FLOAT", "DOUBLE_PRECISION"},
    )
    # NOTE(review): overrides DefaultImpl's value with an empty tuple,
    # presumably so no identity attributes are skipped when comparing
    # identity columns -- confirm against DefaultImpl.
    identity_attrs_ignore = ()

    def __init__(self, *arg, **kw) -> None:
        """Initialize via DefaultImpl, then resolve the batch separator.

        The separator is taken from the ``oracle_batch_separator`` entry of
        ``self.context_opts`` when present, otherwise the class default.
        """
        super().__init__(*arg, **kw)
        self.batch_separator = self.context_opts.get(
            "oracle_batch_separator", self.batch_separator
        )

    def _exec(self, construct: Any, *args, **kw) -> Optional[CursorResult]:
        """Execute ``construct`` through DefaultImpl.

        In ``as_sql`` mode, if a non-empty batch separator is configured,
        it is written via ``static_output`` right after the statement.

        :return: whatever ``DefaultImpl._exec`` returned.
        """
        result = super()._exec(construct, *args, **kw)
        if self.as_sql and self.batch_separator:
            self.static_output(self.batch_separator)
        return result

    def compare_server_default(
        self,
        inspector_column,
        metadata_column,
        rendered_metadata_default,
        rendered_inspector_default,
    ):
        """Return True if the reflected and metadata server defaults differ.

        Each non-None rendered default is normalized before comparison:
        one layer of enclosing parentheses is removed, then one layer of
        single quotes (optionally surrounded by double quotes).  The
        reflected (inspector) value is additionally stripped of surrounding
        whitespace both before and after that unwrapping.  Stripping first
        matters because the unwrap patterns are anchored with ``^``/``$``:
        a padded value such as ``"'abc' "`` would otherwise not match, keep
        its quotes, and be reported as different from ``'abc'``.
        """

        def unwrap(text: str) -> str:
            # Drop one level of "( ... )", then one level of "' ... '"
            # (optionally inside double quotes).
            text = re.sub(r"^\((.+)\)$", r"\1", text)
            return re.sub(r"^\"?'(.+)'\"?$", r"\1", text)

        if rendered_metadata_default is not None:
            rendered_metadata_default = unwrap(rendered_metadata_default)

        if rendered_inspector_default is not None:
            # Strip before unwrapping so the anchored patterns can match;
            # strip again afterwards, as the original code did.
            rendered_inspector_default = unwrap(
                rendered_inspector_default.strip()
            ).strip()

        return rendered_inspector_default != rendered_metadata_default

    def emit_begin(self) -> None:
        """Begin a transaction by executing ``SET TRANSACTION READ WRITE``."""
        self._exec("SET TRANSACTION READ WRITE")

    def emit_commit(self) -> None:
        """Commit by executing ``COMMIT``."""
        self._exec("COMMIT")


@compiles(AddColumn, "oracle")
def visit_add_column(
    element: AddColumn, compiler: OracleDDLCompiler, **kw
) -> str:
    """Render ``ALTER TABLE <table> ADD <column specification>``.

    Extra keyword arguments are forwarded to the column-spec renderer.
    """
    table_clause = alter_table(compiler, element.table_name, element.schema)
    column_clause = add_column(compiler, element.column, **kw)
    return f"{table_clause} {column_clause}"


@compiles(ColumnNullable, "oracle")
def visit_column_nullable(
    element: ColumnNullable, compiler: OracleDDLCompiler, **kw
) -> str:
    """Render ``ALTER TABLE ... MODIFY <column> NULL | NOT NULL``."""
    table_clause = alter_table(compiler, element.table_name, element.schema)
    column_clause = alter_column(compiler, element.column_name)
    if element.nullable:
        constraint = "NULL"
    else:
        constraint = "NOT NULL"
    return f"{table_clause} {column_clause} {constraint}"


@compiles(ColumnType, "oracle")
def visit_column_type(
    element: ColumnType, compiler: OracleDDLCompiler, **kw
) -> str:
    """Render ``ALTER TABLE ... MODIFY <column> <new type>``."""
    table_clause = alter_table(compiler, element.table_name, element.schema)
    column_clause = alter_column(compiler, element.column_name)
    new_type = format_type(compiler, element.type_)
    return f"{table_clause} {column_clause} {new_type}"


@compiles(ColumnName, "oracle")
def visit_column_name(
    element: ColumnName, compiler: OracleDDLCompiler, **kw
) -> str:
    """Render ``ALTER TABLE ... RENAME COLUMN <old> TO <new>``."""
    table_clause = alter_table(compiler, element.table_name, element.schema)
    old_name = format_column_name(compiler, element.column_name)
    new_name = format_column_name(compiler, element.newname)
    return f"{table_clause} RENAME COLUMN {old_name} TO {new_name}"


@compiles(ColumnDefault, "oracle")
def visit_column_default(
    element: ColumnDefault, compiler: OracleDDLCompiler, **kw
) -> str:
    """Render ``ALTER TABLE ... MODIFY <column> DEFAULT <expr>``.

    A ``None`` default is rendered as ``DEFAULT NULL``.
    """
    table_clause = alter_table(compiler, element.table_name, element.schema)
    column_clause = alter_column(compiler, element.column_name)
    if element.default is None:
        default_clause = "DEFAULT NULL"
    else:
        rendered = format_server_default(compiler, element.default)
        default_clause = f"DEFAULT {rendered}"
    return f"{table_clause} {column_clause} {default_clause}"


@compiles(ColumnComment, "oracle")
def visit_column_comment(
    element: ColumnComment, compiler: OracleDDLCompiler, **kw
) -> str:
    """Render ``COMMENT ON COLUMN <table>.<column> IS <literal>``.

    The table reference is schema-qualified (``element.schema``), and both
    the table and column names go through the dialect's identifier
    formatting, the same way the other visitors in this module format them.
    This ensures the statement targets the right table and that names
    needing quoting (mixed case, reserved words) are rendered correctly.

    The comment text is rendered as a string literal; a ``None`` comment
    becomes the empty literal ``''`` (presumably clearing the comment,
    since Oracle treats ``''`` as NULL -- confirm).
    """
    ddl = "COMMENT ON COLUMN {table_name}.{column_name} IS {comment}"

    comment = compiler.sql_compiler.render_literal_value(
        (element.comment if element.comment is not None else ""),
        sqltypes.String(),
    )

    return ddl.format(
        table_name=format_table_name(
            compiler, element.table_name, element.schema
        ),
        column_name=format_column_name(compiler, element.column_name),
        comment=comment,
    )


@compiles(RenameTable, "oracle")
def visit_rename_table(
    element: RenameTable, compiler: OracleDDLCompiler, **kw
) -> str:
    """Render ``ALTER TABLE ... RENAME TO <new name>``.

    The new name is formatted without a schema (``None`` is passed as the
    schema argument).
    """
    table_clause = alter_table(compiler, element.table_name, element.schema)
    target = format_table_name(compiler, element.new_table_name, None)
    return f"{table_clause} RENAME TO {target}"


def alter_column(compiler: OracleDDLCompiler, name: str) -> str:
    """Return the ``MODIFY <formatted column name>`` clause fragment."""
    formatted = format_column_name(compiler, name)
    return f"MODIFY {formatted}"


def add_column(compiler: OracleDDLCompiler, column: Column, **kw) -> str:
    """Return ``ADD <column specification>``.

    The specification comes from ``compiler.get_column_specification``,
    which receives ``column`` and any extra keyword arguments unchanged.
    """
    spec = compiler.get_column_specification(column, **kw)
    return f"ADD {spec}"


@compiles(IdentityColumnDefault, "oracle")
def visit_identity_column(
    element: IdentityColumnDefault, compiler: OracleDDLCompiler, **kw
):
    """Render ``ALTER TABLE ... MODIFY <column>`` plus an identity clause.

    When ``element.default`` is ``None`` the clause is ``DROP IDENTITY``;
    otherwise it is produced by ``compiler.visit_identity_column`` for the
    new default.
    """
    table_clause = alter_table(compiler, element.table_name, element.schema)
    column_clause = alter_column(compiler, element.column_name)
    if element.default is not None:
        identity_clause = compiler.visit_identity_column(element.default)
    else:
        # No identity requested: remove the existing one.
        identity_clause = "DROP IDENTITY"
    return f"{table_clause} {column_clause} {identity_clause}"