1+ import logging
12from typing import List
23
3- from antlr4 import CommonTokenStream , InputStream , Token
4+ from antlr4 import CommonTokenStream , InputStream , RecognitionException , Token
45from antlr4 .error .ErrorListener import ErrorListener
56
67from cratedb_sqlparse .generated_parser .SqlBaseLexer import SqlBaseLexer
@@ -30,7 +31,51 @@ def END_DOLLAR_QUOTED_STRING_sempred(self, localctx, predIndex) -> bool:
3031
3132
3233class ParsingException (Exception ):
33- pass
34+ def __init__ (self , * , query : str , msg : str , offending_token : Token , e : RecognitionException ):
35+ self .message = msg
36+ self .offending_token = offending_token
37+ self .e = e
38+ self .query = query
39+
40+ @property
41+ def error_message (self ):
42+ return f"{ self !r} [line { self .line } :{ self .column } { self .message } ]"
43+
44+ @property
45+ def original_query_with_error_marked (self ):
46+ query = self .offending_token .source [1 ].strdata
47+ offending_token_text : str = query [self .offending_token .start : self .offending_token .stop + 1 ]
48+ query_lines : list = query .split ("\n " )
49+
50+ offending_line : str = query_lines [self .line - 1 ]
51+
52+ # White spaces from the beginning of the offending line to the offending text, so the '^'
53+ # chars are correctly placed below the offending token.
54+ newline_offset = offending_line .index (offending_token_text )
55+ newline = (
56+ offending_line
57+ + "\n "
58+ + (" " * newline_offset + "^" * (self .offending_token .stop - self .offending_token .start + 1 ))
59+ )
60+
61+ query_lines [self .line - 1 ] = newline
62+
63+ msg = "\n " .join (query_lines )
64+ return msg
65+
66+ @property
67+ def column (self ):
68+ return self .offending_token .column
69+
70+ @property
71+ def line (self ):
72+ return self .offending_token .line
73+
74+ def __repr__ (self ):
75+ return f"{ type (self .e ).__qualname__ } "
76+
77+ def __str__ (self ):
78+ return repr (self )
3479
3580
3681class CaseInsensitiveStream (InputStream ):
@@ -47,16 +92,44 @@ class ExceptionErrorListener(ErrorListener):
4792 """
4893
4994 def syntaxError (self , recognizer , offendingSymbol , line , column , msg , e ):
50- raise ParsingException (f"line{ line } :{ column } { msg } " )
95+ error = ParsingException (
96+ msg = msg ,
97+ offending_token = offendingSymbol ,
98+ e = e ,
99+ query = e .ctx .parser .getTokenStream ().getText (e .ctx .start , e .offendingToken .tokenIndex ),
100+ )
101+ raise error
102+
103+
104+ class ExceptionCollectorListener (ErrorListener ):
105+ """
106+ Error listener that collects all errors into errors for further processing.
107+
108+ Based partially on https://github.com/antlr/antlr4/issues/396
109+ """
110+
111+ def __init__ (self ):
112+ self .errors = []
113+
114+ def syntaxError (self , recognizer , offendingSymbol , line , column , msg , e ):
115+ error = ParsingException (
116+ msg = msg ,
117+ offending_token = offendingSymbol ,
118+ e = e ,
119+ query = e .ctx .parser .getTokenStream ().getText (e .ctx .start , e .offendingToken .tokenIndex ),
120+ )
121+
122+ self .errors .append (error )
51123
52124
53125class Statement :
54126 """
55127 Represents a CrateDB SQL statement.
56128 """
57129
58- def __init__ (self , ctx : SqlBaseParser .StatementContext ):
130+ def __init__ (self , ctx : SqlBaseParser .StatementContext , exception : ParsingException = None ):
59131 self .ctx : SqlBaseParser .StatementContext = ctx
132+ self .exception = exception
60133
61134 @property
62135 def tree (self ):
@@ -77,7 +150,7 @@ def query(self) -> str:
77150 """
78151 Returns the query, comments and ';' are not included.
79152 """
80- return self .ctx .parser .getTokenStream ().getText (start = self .ctx .start . tokenIndex , stop = self .ctx .stop . tokenIndex )
153+ return self .ctx .parser .getTokenStream ().getText (start = self .ctx .start , stop = self .ctx .stop )
81154
82155 @property
83156 def type (self ):
@@ -90,7 +163,20 @@ def __repr__(self):
90163 return f'{ self .__class__ .__qualname__ } <{ self .query if len (self .query ) < 15 else self .query [:15 ] + "..." } >'
91164
92165
93- def sqlparse (query : str ) -> List [Statement ]:
166+ def find_suitable_error (statement , errors ):
167+ for error in errors [:]:
168+ # We clean the error_query of ';' and spaces because ironically,
169+ # we can get the full query in the error handler but not in the context.
170+ error_query = error .query
171+ if error_query .endswith (";" ):
172+ error_query = error_query [: len (error_query ) - 1 ]
173+
174+ if error_query .lstrip ().rstrip () == statement .query :
175+ statement .exception = error
176+ errors .pop (errors .index (error ))
177+
178+
179+ def sqlparse (query : str , raise_exception : bool = False ) -> List [Statement ]:
94180 """
95181 Parses a string into SQL `Statement`.
96182 """
@@ -101,12 +187,42 @@ def sqlparse(query: str) -> List[Statement]:
101187
102188 parser = SqlBaseParser (stream )
103189 parser .removeErrorListeners ()
104- parser .addErrorListener (ExceptionErrorListener ())
190+ error_listener = ExceptionErrorListener () if raise_exception else ExceptionCollectorListener ()
191+ parser .addErrorListener (error_listener )
105192
106193 tree = parser .statements ()
107194
108- # At this point, all errors are already raised; it's seasonably safe to assume
109- # that the statements are valid.
110- statements = list (filter (lambda children : isinstance (children , SqlBaseParser .StatementContext ), tree .children ))
111-
112- return [Statement (statement ) for statement in statements ]
195+ statements_context : list [SqlBaseParser .StatementContext ] = list (
196+ filter (lambda children : isinstance (children , SqlBaseParser .StatementContext ), tree .children )
197+ )
198+
199+ statements = []
200+ for statement_context in statements_context :
201+ _stmt = Statement (statement_context )
202+ find_suitable_error (_stmt , error_listener .errors )
203+ statements .append (_stmt )
204+
205+ else :
206+ # We might still have error(s) that we couldn't match with their origin statement,
207+ # this happens when the query is composed of only one keyword, e.g. 'SELCT 1'
208+ # the error.query will be 'SELCT' instead of 'SELCT 1'.
209+ if len (error_listener .errors ) == 1 :
210+ # This case has an edge case where we hypothetically assign the
211+ # wrong error to a statement, for example:
212+ # SELECT A FROM tbl1;
213+ # SELEC 1;
214+ # This would match both conditionals, this however is protected by
215+ # by https://github.com/crate/cratedb-sqlparse/issues/28, but might
216+ # change in the future.
217+ error = error_listener .errors [0 ]
218+ for _stmt in statements :
219+ if _stmt .exception is None and error .query in _stmt .query :
220+ _stmt .exception = error
221+ break
222+
223+ if len (error_listener .errors ) > 1 :
224+ logging .error (
225+ "Could not match errors to queries, too much ambiguity, open an issue with this error and the query."
226+ )
227+
228+ return statements
0 commit comments