Skip to content

Commit 61eda9f

Browse files
committed
Implement ExceptionCollectorListener and make it default behaviour
1 parent 73a1b82 commit 61eda9f

File tree

1 file changed

+128
-12
lines changed
  • cratedb_sqlparse_py/cratedb_sqlparse

1 file changed

+128
-12
lines changed

cratedb_sqlparse_py/cratedb_sqlparse/parser.py

Lines changed: 128 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
1+
import logging
12
from typing import List
23

3-
from antlr4 import CommonTokenStream, InputStream, Token
4+
from antlr4 import CommonTokenStream, InputStream, RecognitionException, Token
45
from antlr4.error.ErrorListener import ErrorListener
56

67
from cratedb_sqlparse.generated_parser.SqlBaseLexer import SqlBaseLexer
@@ -30,7 +31,51 @@ def END_DOLLAR_QUOTED_STRING_sempred(self, localctx, predIndex) -> bool:
3031

3132

3233
class ParsingException(Exception):
33-
pass
34+
def __init__(self, *, query: str, msg: str, offending_token: Token, e: RecognitionException):
35+
self.message = msg
36+
self.offending_token = offending_token
37+
self.e = e
38+
self.query = query
39+
40+
@property
41+
def error_message(self):
42+
return f"{self!r}[line {self.line}:{self.column} {self.message}]"
43+
44+
@property
45+
def original_query_with_error_marked(self):
46+
query = self.offending_token.source[1].strdata
47+
offending_token_text: str = query[self.offending_token.start : self.offending_token.stop + 1]
48+
query_lines: list = query.split("\n")
49+
50+
offending_line: str = query_lines[self.line - 1]
51+
52+
# White spaces from the beginning of the offending line to the offending text, so the '^'
53+
# chars are correctly placed below the offending token.
54+
newline_offset = offending_line.index(offending_token_text)
55+
newline = (
56+
offending_line
57+
+ "\n"
58+
+ (" " * newline_offset + "^" * (self.offending_token.stop - self.offending_token.start + 1))
59+
)
60+
61+
query_lines[self.line - 1] = newline
62+
63+
msg = "\n".join(query_lines)
64+
return msg
65+
66+
@property
67+
def column(self):
68+
return self.offending_token.column
69+
70+
@property
71+
def line(self):
72+
return self.offending_token.line
73+
74+
def __repr__(self):
75+
return f"{type(self.e).__qualname__}"
76+
77+
def __str__(self):
78+
return repr(self)
3479

3580

3681
class CaseInsensitiveStream(InputStream):
@@ -47,16 +92,44 @@ class ExceptionErrorListener(ErrorListener):
4792
"""
4893

4994
def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
50-
raise ParsingException(f"line{line}:{column} {msg}")
95+
error = ParsingException(
96+
msg=msg,
97+
offending_token=offendingSymbol,
98+
e=e,
99+
query=e.ctx.parser.getTokenStream().getText(e.ctx.start, e.offendingToken.tokenIndex),
100+
)
101+
raise error
102+
103+
104+
class ExceptionCollectorListener(ErrorListener):
105+
"""
106+
Error listener that collects all errors into errors for further processing.
107+
108+
Based partially on https://github.com/antlr/antlr4/issues/396
109+
"""
110+
111+
def __init__(self):
112+
self.errors = []
113+
114+
def syntaxError(self, recognizer, offendingSymbol, line, column, msg, e):
115+
error = ParsingException(
116+
msg=msg,
117+
offending_token=offendingSymbol,
118+
e=e,
119+
query=e.ctx.parser.getTokenStream().getText(e.ctx.start, e.offendingToken.tokenIndex),
120+
)
121+
122+
self.errors.append(error)
51123

52124

53125
class Statement:
54126
"""
55127
Represents a CrateDB SQL statement.
56128
"""
57129

58-
def __init__(self, ctx: SqlBaseParser.StatementContext):
130+
def __init__(self, ctx: SqlBaseParser.StatementContext, exception: ParsingException = None):
59131
self.ctx: SqlBaseParser.StatementContext = ctx
132+
self.exception = exception
60133

61134
@property
62135
def tree(self):
@@ -77,7 +150,7 @@ def query(self) -> str:
77150
"""
78151
Returns the query, comments and ';' are not included.
79152
"""
80-
return self.ctx.parser.getTokenStream().getText(start=self.ctx.start.tokenIndex, stop=self.ctx.stop.tokenIndex)
153+
return self.ctx.parser.getTokenStream().getText(start=self.ctx.start, stop=self.ctx.stop)
81154

82155
@property
83156
def type(self):
@@ -90,7 +163,20 @@ def __repr__(self):
90163
return f'{self.__class__.__qualname__}<{self.query if len(self.query) < 15 else self.query[:15] + "..."}>'
91164

92165

93-
def sqlparse(query: str) -> List[Statement]:
166+
def find_suitable_error(statement, errors):
167+
for error in errors[:]:
168+
# We clean the error_query of ';' and spaces because ironically,
169+
# we can get the full query in the error handler but not in the context.
170+
error_query = error.query
171+
if error_query.endswith(";"):
172+
error_query = error_query[: len(error_query) - 1]
173+
174+
if error_query.lstrip().rstrip() == statement.query:
175+
statement.exception = error
176+
errors.pop(errors.index(error))
177+
178+
179+
def sqlparse(query: str, raise_exception: bool = False) -> List[Statement]:
94180
"""
95181
Parses a string into SQL `Statement`.
96182
"""
@@ -101,12 +187,42 @@ def sqlparse(query: str) -> List[Statement]:
101187

102188
parser = SqlBaseParser(stream)
103189
parser.removeErrorListeners()
104-
parser.addErrorListener(ExceptionErrorListener())
190+
error_listener = ExceptionErrorListener() if raise_exception else ExceptionCollectorListener()
191+
parser.addErrorListener(error_listener)
105192

106193
tree = parser.statements()
107194

108-
# At this point, all errors are already raised; it's seasonably safe to assume
109-
# that the statements are valid.
110-
statements = list(filter(lambda children: isinstance(children, SqlBaseParser.StatementContext), tree.children))
111-
112-
return [Statement(statement) for statement in statements]
195+
statements_context: list[SqlBaseParser.StatementContext] = list(
196+
filter(lambda children: isinstance(children, SqlBaseParser.StatementContext), tree.children)
197+
)
198+
199+
statements = []
200+
for statement_context in statements_context:
201+
_stmt = Statement(statement_context)
202+
find_suitable_error(_stmt, error_listener.errors)
203+
statements.append(_stmt)
204+
205+
else:
206+
# We might still have error(s) that we couldn't match with their origin statement,
207+
# this happens when the query is composed of only one keyword, e.g. 'SELCT 1'
208+
# the error.query will be 'SELCT' instead of 'SELCT 1'.
209+
if len(error_listener.errors) == 1:
210+
# This case has an edge case where we hypothetically assign the
211+
# wrong error to a statement, for example:
212+
# SELECT A FROM tbl1;
213+
# SELEC 1;
214+
# This would match both conditionals, this however is protected by
215+
# by https://github.com/crate/cratedb-sqlparse/issues/28, but might
216+
# change in the future.
217+
error = error_listener.errors[0]
218+
for _stmt in statements:
219+
if _stmt.exception is None and error.query in _stmt.query:
220+
_stmt.exception = error
221+
break
222+
223+
if len(error_listener.errors) > 1:
224+
logging.error(
225+
"Could not match errors to queries, too much ambiguity, open an issue with this error and the query."
226+
)
227+
228+
return statements

0 commit comments

Comments
 (0)