@@ -37,6 +37,7 @@ use crate::tokenizer::*;
37
37
/// Errors returned while parsing a SQL string.
pub enum ParserError {
    /// The tokenizer failed before parsing began; the payload is the
    /// tokenizer's error message.
    TokenizerError(String),
    /// A parse error; the payload is a human-readable message.
    ParserError(String),
    /// The parser exceeded its recursion depth limit
    /// (see `Parser::with_recursion_limit`).
    RecursionLimitExceeded,
}
41
42
42
43
// Use `Parser::expected` instead, if possible
@@ -55,6 +56,92 @@ macro_rules! return_ok_if_some {
55
56
} } ;
56
57
}
57
58
59
/// Implementation of [`RecursionCounter`] used when `std` is available.
#[cfg(feature = "std")]
mod recursion {
    use core::sync::atomic::{AtomicUsize, Ordering};
    use std::rc::Rc;

    use super::ParserError;

    /// Tracks the remaining recursion depth. The value is decremented on
    /// each call to [`RecursionCounter::try_decrease`]; when it reaches 0
    /// an error is returned.
    ///
    /// Note: uses an `Rc<AtomicUsize>` in order to satisfy the Rust
    /// borrow checker, so the automatically dropped [`DepthGuard`] can
    /// hold a reference to the counter. The value is never actually
    /// modified concurrently.
    pub(crate) struct RecursionCounter {
        remaining_depth: Rc<AtomicUsize>,
    }

    impl RecursionCounter {
        /// Creates a [`RecursionCounter`] with the specified maximum depth.
        pub fn new(remaining_depth: usize) -> Self {
            Self {
                remaining_depth: Rc::new(AtomicUsize::new(remaining_depth)),
            }
        }

        /// Decreases the remaining depth by 1.
        ///
        /// Returns [`ParserError::RecursionLimitExceeded`] if the
        /// remaining depth was already 0.
        ///
        /// On success, returns a [`DepthGuard`] that adds 1 back to the
        /// remaining depth when it is dropped.
        pub fn try_decrease(&self) -> Result<DepthGuard, ParserError> {
            let old_value = self.remaining_depth.fetch_sub(1, Ordering::SeqCst);
            if old_value == 0 {
                // Ran out of depth. Undo the decrement so the counter does
                // not wrap to usize::MAX, which would otherwise let every
                // subsequent try_decrease() succeed after the first failure.
                self.remaining_depth.fetch_add(1, Ordering::SeqCst);
                Err(ParserError::RecursionLimitExceeded)
            } else {
                Ok(DepthGuard::new(Rc::clone(&self.remaining_depth)))
            }
        }
    }

    /// Guard that increases the remaining depth by 1 on drop.
    pub struct DepthGuard {
        remaining_depth: Rc<AtomicUsize>,
    }

    impl DepthGuard {
        /// Wraps a shared counter; the matching increment happens in `drop`.
        fn new(remaining_depth: Rc<AtomicUsize>) -> Self {
            Self { remaining_depth }
        }
    }

    impl Drop for DepthGuard {
        fn drop(&mut self) {
            // Give the unit of depth taken by try_decrease() back.
            self.remaining_depth.fetch_add(1, Ordering::SeqCst);
        }
    }
}
121
+
122
+ #[ cfg( not( feature = "std" ) ) ]
123
+ mod recursion {
124
+ /// Implemenation [`RecursionCounter`] if std is NOT available (and does not
125
+ /// guard against stack overflow).
126
+ ///
127
+ /// Has the same API as the std RecursionCounter implementation
128
+ /// but does not actually limit stack depth.
129
+ pub ( crate ) struct RecursionCounter { }
130
+
131
+ impl RecursionCounter {
132
+ pub fn new ( _remaining_depth : usize ) -> Self {
133
+ Self { }
134
+ }
135
+ pub fn try_decrease ( & self ) -> Result < DepthGuard , super :: ParserError > {
136
+ Ok ( DepthGuard { } )
137
+ }
138
+ }
139
+
140
+ pub struct DepthGuard { }
141
+ }
142
+
143
+ use recursion:: RecursionCounter ;
144
+
58
145
#[ derive( PartialEq , Eq ) ]
59
146
pub enum IsOptional {
60
147
Optional ,
@@ -96,6 +183,7 @@ impl fmt::Display for ParserError {
96
183
match self {
97
184
ParserError :: TokenizerError ( s) => s,
98
185
ParserError :: ParserError ( s) => s,
186
+ ParserError :: RecursionLimitExceeded => "recursion limit exceeded" ,
99
187
}
100
188
)
101
189
}
@@ -104,22 +192,78 @@ impl fmt::Display for ParserError {
104
192
#[ cfg( feature = "std" ) ]
105
193
impl std:: error:: Error for ParserError { }
106
194
195
// By default, allow expressions up to this deep before erroring
const DEFAULT_REMAINING_DEPTH: usize = 50;

/// SQL parser: consumes tokens produced by the [`Tokenizer`] and
/// produces statements for the configured [`Dialect`].
pub struct Parser<'a> {
    // The token stream being parsed (with source locations).
    tokens: Vec<TokenWithLocation>,
    /// The index of the first unprocessed token in `self.tokens`
    index: usize,
    /// The current dialect to use
    dialect: &'a dyn Dialect,
    /// Ensure the stack does not overflow by limiting recursion depth
    recursion_counter: RecursionCounter,
}
113
207
114
208
impl < ' a > Parser < ' a > {
115
- /// Parse the specified tokens
116
- /// To avoid breaking backwards compatibility, this function accepts
117
- /// bare tokens.
118
- pub fn new ( tokens : Vec < Token > , dialect : & ' a dyn Dialect ) -> Self {
119
- Parser :: new_without_locations ( tokens, dialect)
209
+ /// Create a parser for a [`Dialect`]
210
+ ///
211
+ /// See also [`Parser::parse_sql`]
212
+ ///
213
+ /// Example:
214
+ /// ```
215
+ /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
216
+ /// # fn main() -> Result<(), ParserError> {
217
+ /// let dialect = GenericDialect{};
218
+ /// let statements = Parser::new(&dialect)
219
+ /// .try_with_sql("SELECT * FROM foo")?
220
+ /// .parse_statements()?;
221
+ /// # Ok(())
222
+ /// # }
223
+ /// ```
224
+ pub fn new ( dialect : & ' a dyn Dialect ) -> Self {
225
+ Self {
226
+ tokens : vec ! [ ] ,
227
+ index : 0 ,
228
+ dialect,
229
+ recursion_counter : RecursionCounter :: new ( DEFAULT_REMAINING_DEPTH ) ,
230
+ }
231
+ }
232
+
233
+ /// Specify the maximum recursion limit while parsing.
234
+ ///
235
+ ///
236
+ /// [`Parser`] prevents stack overflows by returning
237
+ /// [`ParserError::RecursionLimitExceeded`] if the parser exceeds
238
+ /// this depth while processing the query.
239
+ ///
240
+ /// Example:
241
+ /// ```
242
+ /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
243
+ /// # fn main() -> Result<(), ParserError> {
244
+ /// let dialect = GenericDialect{};
245
+ /// let result = Parser::new(&dialect)
246
+ /// .with_recursion_limit(1)
247
+ /// .try_with_sql("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))")?
248
+ /// .parse_statements();
249
+ /// assert_eq!(result, Err(ParserError::RecursionLimitExceeded));
250
+ /// # Ok(())
251
+ /// # }
252
+ /// ```
253
+ pub fn with_recursion_limit ( mut self , recursion_limit : usize ) -> Self {
254
+ self . recursion_counter = RecursionCounter :: new ( recursion_limit) ;
255
+ self
256
+ }
257
+
258
+ /// Reset this parser to parse the specified token stream
259
+ pub fn with_tokens_with_locations ( mut self , tokens : Vec < TokenWithLocation > ) -> Self {
260
+ self . tokens = tokens;
261
+ self . index = 0 ;
262
+ self
120
263
}
121
264
122
- pub fn new_without_locations ( tokens : Vec < Token > , dialect : & ' a dyn Dialect ) -> Self {
265
+ /// Reset this parser state to parse the specified tokens
266
+ pub fn with_tokens ( self , tokens : Vec < Token > ) -> Self {
123
267
// Put in dummy locations
124
268
let tokens_with_locations: Vec < TokenWithLocation > = tokens
125
269
. into_iter ( )
@@ -128,49 +272,84 @@ impl<'a> Parser<'a> {
128
272
location : Location { line : 0 , column : 0 } ,
129
273
} )
130
274
. collect ( ) ;
131
- Parser :: new_with_locations ( tokens_with_locations, dialect )
275
+ self . with_tokens_with_locations ( tokens_with_locations)
132
276
}
133
277
134
- /// Parse the specified tokens
135
- pub fn new_with_locations ( tokens : Vec < TokenWithLocation > , dialect : & ' a dyn Dialect ) -> Self {
136
- Parser {
137
- tokens,
138
- index : 0 ,
139
- dialect,
140
- }
278
+ /// Tokenize the sql string and sets this [`Parser`]'s state to
279
+ /// parse the resulting tokens
280
+ ///
281
+ /// Returns an error if there was an error tokenizing the SQL string.
282
+ ///
283
+ /// See example on [`Parser::new()`] for an example
284
+ pub fn try_with_sql ( self , sql : & str ) -> Result < Self , ParserError > {
285
+ debug ! ( "Parsing sql '{}'..." , sql) ;
286
+ let mut tokenizer = Tokenizer :: new ( self . dialect , sql) ;
287
+ let tokens = tokenizer. tokenize ( ) ?;
288
+ Ok ( self . with_tokens ( tokens) )
141
289
}
142
290
143
- /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
144
- pub fn parse_sql ( dialect : & dyn Dialect , sql : & str ) -> Result < Vec < Statement > , ParserError > {
145
- let mut tokenizer = Tokenizer :: new ( dialect, sql) ;
146
- let tokens = tokenizer. tokenize ( ) ?;
147
- let mut parser = Parser :: new ( tokens, dialect) ;
291
+ /// Parse potentially multiple statements
292
+ ///
293
+ /// Example
294
+ /// ```
295
+ /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
296
+ /// # fn main() -> Result<(), ParserError> {
297
+ /// let dialect = GenericDialect{};
298
+ /// let statements = Parser::new(&dialect)
299
+ /// // Parse a SQL string with 2 separate statements
300
+ /// .try_with_sql("SELECT * FROM foo; SELECT * FROM bar;")?
301
+ /// .parse_statements()?;
302
+ /// assert_eq!(statements.len(), 2);
303
+ /// # Ok(())
304
+ /// # }
305
+ /// ```
306
+ pub fn parse_statements ( & mut self ) -> Result < Vec < Statement > , ParserError > {
148
307
let mut stmts = Vec :: new ( ) ;
149
308
let mut expecting_statement_delimiter = false ;
150
- debug ! ( "Parsing sql '{}'..." , sql) ;
151
309
loop {
152
310
// ignore empty statements (between successive statement delimiters)
153
- while parser . consume_token ( & Token :: SemiColon ) {
311
+ while self . consume_token ( & Token :: SemiColon ) {
154
312
expecting_statement_delimiter = false ;
155
313
}
156
314
157
- if parser . peek_token ( ) == Token :: EOF {
315
+ if self . peek_token ( ) == Token :: EOF {
158
316
break ;
159
317
}
160
318
if expecting_statement_delimiter {
161
- return parser . expected ( "end of statement" , parser . peek_token ( ) ) ;
319
+ return self . expected ( "end of statement" , self . peek_token ( ) ) ;
162
320
}
163
321
164
- let statement = parser . parse_statement ( ) ?;
322
+ let statement = self . parse_statement ( ) ?;
165
323
stmts. push ( statement) ;
166
324
expecting_statement_delimiter = true ;
167
325
}
168
326
Ok ( stmts)
169
327
}
170
328
329
+ /// Convience method to parse a string with one or more SQL
330
+ /// statements into produce an Abstract Syntax Tree (AST).
331
+ ///
332
+ /// Example
333
+ /// ```
334
+ /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
335
+ /// # fn main() -> Result<(), ParserError> {
336
+ /// let dialect = GenericDialect{};
337
+ /// let statements = Parser::parse_sql(
338
+ /// &dialect, "SELECT * FROM foo"
339
+ /// )?;
340
+ /// assert_eq!(statements.len(), 1);
341
+ /// # Ok(())
342
+ /// # }
343
+ /// ```
344
+ pub fn parse_sql ( dialect : & dyn Dialect , sql : & str ) -> Result < Vec < Statement > , ParserError > {
345
+ Parser :: new ( dialect) . try_with_sql ( sql) ?. parse_statements ( )
346
+ }
347
+
171
348
/// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
172
349
/// stopping before the statement separator, if any.
173
350
pub fn parse_statement ( & mut self ) -> Result < Statement , ParserError > {
351
+ let _guard = self . recursion_counter . try_decrease ( ) ?;
352
+
174
353
// allow the dialect to override statement parsing
175
354
if let Some ( statement) = self . dialect . parse_statement ( self ) {
176
355
return statement;
@@ -364,6 +543,7 @@ impl<'a> Parser<'a> {
364
543
365
544
    /// Parse a new expression
    ///
    /// Takes one unit of recursion depth first, so deeply nested
    /// expressions return [`ParserError::RecursionLimitExceeded`]
    /// instead of overflowing the stack; the guard returns the unit
    /// when this call finishes. Delegates to `parse_subexpr` with an
    /// initial argument of 0.
    pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
        let _guard = self.recursion_counter.try_decrease()?;
        self.parse_subexpr(0)
    }
369
549
@@ -4512,6 +4692,7 @@ impl<'a> Parser<'a> {
4512
4692
/// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
4513
4693
/// expect the initial keyword to be already consumed
4514
4694
pub fn parse_query ( & mut self ) -> Result < Query , ParserError > {
4695
+ let _guard = self . recursion_counter . try_decrease ( ) ?;
4515
4696
let with = if self . parse_keyword ( Keyword :: WITH ) {
4516
4697
Some ( With {
4517
4698
recursive : self . parse_keyword ( Keyword :: RECURSIVE ) ,
0 commit comments