99# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1010# See the License for the specific language governing permissions and
1111# limitations under the License.
12+ from typing import Any , Optional
13+
1214from sqlalchemy .sql import compiler
1315from sqlalchemy .sql .base import DialectKWArgs
16+ from sqlalchemy .sql .schema import Table
1417
1518# https://trino.io/docs/current/language/reserved.html
1619RESERVED_WORDS = {
9295
9396
9497class TrinoSQLCompiler (compiler .SQLCompiler ):
95- def limit_clause (self , select , ** kw ) :
98+ def limit_clause (self , select : Any , ** kw : dict [ str , Any ]) -> str :
9699 """
97100 Trino support only OFFSET...LIMIT but not LIMIT...OFFSET syntax.
98101 """
@@ -103,15 +106,15 @@ def limit_clause(self, select, **kw):
103106 text += "\n LIMIT " + self .process (select ._limit_clause , ** kw )
104107 return text
105108
106- def visit_table (self , table , asfrom = False , iscrud = False , ashint = False ,
107- fromhints = None , use_schema = True , ** kwargs ) :
109+ def visit_table (self , table : Table , asfrom : bool = False , iscrud : bool = False , ashint : bool = False ,
110+ fromhints : Optional [ Any ] = None , use_schema : bool = True , ** kwargs : Any ) -> str :
108111 sql = super (TrinoSQLCompiler , self ).visit_table (
109112 table , asfrom , iscrud , ashint , fromhints , use_schema , ** kwargs
110113 )
111114 return self .add_catalog (sql , table )
112115
113116 @staticmethod
114- def add_catalog (sql , table ) :
117+ def add_catalog (sql : str , table : Table ) -> str :
115118 if table is None or not isinstance (table , DialectKWArgs ):
116119 return sql
117120
@@ -131,7 +134,7 @@ class TrinoDDLCompiler(compiler.DDLCompiler):
131134
132135
133136class TrinoTypeCompiler (compiler .GenericTypeCompiler ):
134- def visit_FLOAT (self , type_ , ** kw ) :
137+ def visit_FLOAT (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
135138 precision = type_ .precision or 32
136139 if 0 <= precision <= 32 :
137140 return self .visit_REAL (type_ , ** kw )
@@ -140,37 +143,37 @@ def visit_FLOAT(self, type_, **kw):
140143 else :
141144 raise ValueError (f"type.precision must be in range [0, 64], got { type_ .precision } " )
142145
143- def visit_DOUBLE (self , type_ , ** kw ) :
146+ def visit_DOUBLE (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
144147 return "DOUBLE"
145148
146- def visit_NUMERIC (self , type_ , ** kw ) :
149+ def visit_NUMERIC (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
147150 return self .visit_DECIMAL (type_ , ** kw )
148151
149- def visit_NCHAR (self , type_ , ** kw ) :
152+ def visit_NCHAR (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
150153 return self .visit_CHAR (type_ , ** kw )
151154
152- def visit_NVARCHAR (self , type_ , ** kw ) :
155+ def visit_NVARCHAR (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
153156 return self .visit_VARCHAR (type_ , ** kw )
154157
155- def visit_TEXT (self , type_ , ** kw ) :
158+ def visit_TEXT (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
156159 return self .visit_VARCHAR (type_ , ** kw )
157160
158- def visit_BINARY (self , type_ , ** kw ) :
161+ def visit_BINARY (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
159162 return self .visit_VARBINARY (type_ , ** kw )
160163
161- def visit_CLOB (self , type_ , ** kw ) :
164+ def visit_CLOB (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
162165 return self .visit_VARCHAR (type_ , ** kw )
163166
164- def visit_NCLOB (self , type_ , ** kw ) :
167+ def visit_NCLOB (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
165168 return self .visit_VARCHAR (type_ , ** kw )
166169
167- def visit_BLOB (self , type_ , ** kw ) :
170+ def visit_BLOB (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
168171 return self .visit_VARBINARY (type_ , ** kw )
169172
170- def visit_DATETIME (self , type_ , ** kw ) :
173+ def visit_DATETIME (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
171174 return self .visit_TIMESTAMP (type_ , ** kw )
172175
173- def visit_TIMESTAMP (self , type_ , ** kw ) :
176+ def visit_TIMESTAMP (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
174177 datatype = "TIMESTAMP"
175178 precision = getattr (type_ , "precision" , None )
176179 if precision not in range (0 , 13 ) and precision is not None :
@@ -182,7 +185,7 @@ def visit_TIMESTAMP(self, type_, **kw):
182185
183186 return datatype
184187
185- def visit_TIME (self , type_ , ** kw ) :
188+ def visit_TIME (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
186189 datatype = "TIME"
187190 precision = getattr (type_ , "precision" , None )
188191 if precision not in range (0 , 13 ) and precision is not None :
@@ -193,13 +196,13 @@ def visit_TIME(self, type_, **kw):
193196 datatype += " WITH TIME ZONE"
194197 return datatype
195198
196- def visit_JSON (self , type_ , ** kw ) :
199+ def visit_JSON (self , type_ : Any , ** kw : dict [ str , Any ]) -> str :
197200 return 'JSON'
198201
199202
200203class TrinoIdentifierPreparer (compiler .IdentifierPreparer ):
201204 reserved_words = RESERVED_WORDS
202205
203- def format_table (self , table , use_schema = True , name = None ):
206+ def format_table (self , table : Table , use_schema : bool = True , name : Optional [ str ] = None ) -> str :
204207 result = super (TrinoIdentifierPreparer , self ).format_table (table , use_schema , name )
205208 return TrinoSQLCompiler .add_catalog (result , table )
0 commit comments