Commit 79d0baa

Add configurable recursion limit to parser, to protect against stack overflows (#764)
Parent: 2c20ec0

File tree: 4 files changed, 301 additions, 30 deletions
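
Taken together, the change replaces the token-based `Parser::new(tokens, dialect)` constructor with a builder-style API and wires a depth counter into `parse_statement`, `parse_expr`, and `parse_query`. As a minimal sketch, adapted from the doc examples added in this commit (the SQL strings and the limit of 1 are illustrative values, not recommendations), downstream usage might look like:

```rust
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserError};

fn main() -> Result<(), ParserError> {
    let dialect = GenericDialect {};

    // Default limit (DEFAULT_REMAINING_DEPTH = 50): parses as before.
    let statements = Parser::new(&dialect)
        .try_with_sql("SELECT * FROM foo; SELECT * FROM bar;")?
        .parse_statements()?;
    assert_eq!(statements.len(), 2);

    // With a deliberately tiny limit, deeply nested expressions now fail
    // with RecursionLimitExceeded instead of overflowing the stack.
    let result = Parser::new(&dialect)
        .with_recursion_limit(1)
        .try_with_sql("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))")?
        .parse_statements();
    assert_eq!(result, Err(ParserError::RecursionLimitExceeded));

    Ok(())
}
```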

src/lib.rs: 6 additions, 2 deletions

````diff
@@ -12,11 +12,15 @@
 
 //! SQL Parser for Rust
 //!
-//! Example code:
-//!
 //! This crate provides an ANSI:SQL 2011 lexer and parser that can parse SQL
 //! into an Abstract Syntax Tree (AST).
 //!
+//! See [`Parser::parse_sql`](crate::parser::Parser::parse_sql) and
+//! [`Parser::new`](crate::parser::Parser::new) for the parsing API
+//! and the [`ast`](crate::ast) module for the AST structure.
+//!
+//! Example:
+//!
 //! ```
 //! use sqlparser::dialect::GenericDialect;
 //! use sqlparser::parser::Parser;
````

src/parser.rs: 205 additions, 24 deletions

````diff
@@ -37,6 +37,7 @@ use crate::tokenizer::*;
 pub enum ParserError {
     TokenizerError(String),
     ParserError(String),
+    RecursionLimitExceeded,
 }
 
 // Use `Parser::expected` instead, if possible
@@ -55,6 +56,92 @@ macro_rules! return_ok_if_some {
     }};
 }
 
+#[cfg(feature = "std")]
+/// Implementation of [`RecursionCounter`] used when std is available
+mod recursion {
+    use core::sync::atomic::{AtomicUsize, Ordering};
+    use std::rc::Rc;
+
+    use super::ParserError;
+
+    /// Tracks remaining recursion depth. This value is decremented on
+    /// each call to `try_decrease()`; when it reaches 0 an error will
+    /// be returned.
+    ///
+    /// Note: Uses an Rc and AtomicUsize in order to satisfy the Rust
+    /// borrow checker, so that the automatically dropped DepthGuard can
+    /// hold a reference to the counter and decrement it. The actual
+    /// value is not modified concurrently.
+    pub(crate) struct RecursionCounter {
+        remaining_depth: Rc<AtomicUsize>,
+    }
+
+    impl RecursionCounter {
+        /// Creates a [`RecursionCounter`] with the specified maximum
+        /// depth
+        pub fn new(remaining_depth: usize) -> Self {
+            Self {
+                remaining_depth: Rc::new(remaining_depth.into()),
+            }
+        }
+
+        /// Decreases the remaining depth by 1.
+        ///
+        /// Returns `Err` if the remaining depth falls to 0.
+        ///
+        /// Returns a [`DepthGuard`] which adds 1 back to the
+        /// remaining depth when dropped.
+        pub fn try_decrease(&self) -> Result<DepthGuard, ParserError> {
+            let old_value = self.remaining_depth.fetch_sub(1, Ordering::SeqCst);
+            // ran out of space
+            if old_value == 0 {
+                Err(ParserError::RecursionLimitExceeded)
+            } else {
+                Ok(DepthGuard::new(Rc::clone(&self.remaining_depth)))
+            }
+        }
+    }
+
+    /// Guard that increases the remaining depth by 1 on drop
+    pub struct DepthGuard {
+        remaining_depth: Rc<AtomicUsize>,
+    }
+
+    impl DepthGuard {
+        fn new(remaining_depth: Rc<AtomicUsize>) -> Self {
+            Self { remaining_depth }
+        }
+    }
+    impl Drop for DepthGuard {
+        fn drop(&mut self) {
+            self.remaining_depth.fetch_add(1, Ordering::SeqCst);
+        }
+    }
+}
+
+#[cfg(not(feature = "std"))]
+mod recursion {
+    /// Implementation of [`RecursionCounter`] used when std is NOT available
+    /// (and which does not guard against stack overflow).
+    ///
+    /// Has the same API as the std RecursionCounter implementation
+    /// but does not actually limit stack depth.
+    pub(crate) struct RecursionCounter {}
+
+    impl RecursionCounter {
+        pub fn new(_remaining_depth: usize) -> Self {
+            Self {}
+        }
+        pub fn try_decrease(&self) -> Result<DepthGuard, super::ParserError> {
+            Ok(DepthGuard {})
+        }
+    }
+
+    pub struct DepthGuard {}
+}
+
+use recursion::RecursionCounter;
+
 #[derive(PartialEq, Eq)]
 pub enum IsOptional {
     Optional,
@@ -96,6 +183,7 @@ impl fmt::Display for ParserError {
             match self {
                 ParserError::TokenizerError(s) => s,
                 ParserError::ParserError(s) => s,
+                ParserError::RecursionLimitExceeded => "recursion limit exceeded",
             }
         )
     }
@@ -104,22 +192,78 @@ impl fmt::Display for ParserError {
 #[cfg(feature = "std")]
 impl std::error::Error for ParserError {}
 
+// By default, allow expressions up to this deep before erroring
+const DEFAULT_REMAINING_DEPTH: usize = 50;
+
 pub struct Parser<'a> {
     tokens: Vec<TokenWithLocation>,
     /// The index of the first unprocessed token in `self.tokens`
     index: usize,
+    /// The current dialect to use
     dialect: &'a dyn Dialect,
+    /// Ensure the stack does not overflow by limiting recursion depth
+    recursion_counter: RecursionCounter,
 }
 
 impl<'a> Parser<'a> {
-    /// Parse the specified tokens
-    /// To avoid breaking backwards compatibility, this function accepts
-    /// bare tokens.
-    pub fn new(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
-        Parser::new_without_locations(tokens, dialect)
+    /// Create a parser for a [`Dialect`]
+    ///
+    /// See also [`Parser::parse_sql`]
+    ///
+    /// Example:
+    /// ```
+    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
+    /// # fn main() -> Result<(), ParserError> {
+    /// let dialect = GenericDialect{};
+    /// let statements = Parser::new(&dialect)
+    ///   .try_with_sql("SELECT * FROM foo")?
+    ///   .parse_statements()?;
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn new(dialect: &'a dyn Dialect) -> Self {
+        Self {
+            tokens: vec![],
+            index: 0,
+            dialect,
+            recursion_counter: RecursionCounter::new(DEFAULT_REMAINING_DEPTH),
+        }
+    }
+
+    /// Specify the maximum recursion limit while parsing.
+    ///
+    ///
+    /// [`Parser`] prevents stack overflows by returning
+    /// [`ParserError::RecursionLimitExceeded`] if the parser exceeds
+    /// this depth while processing the query.
+    ///
+    /// Example:
+    /// ```
+    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
+    /// # fn main() -> Result<(), ParserError> {
+    /// let dialect = GenericDialect{};
+    /// let result = Parser::new(&dialect)
+    ///   .with_recursion_limit(1)
+    ///   .try_with_sql("SELECT * FROM foo WHERE (a OR (b OR (c OR d)))")?
+    ///   .parse_statements();
+    /// assert_eq!(result, Err(ParserError::RecursionLimitExceeded));
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn with_recursion_limit(mut self, recursion_limit: usize) -> Self {
+        self.recursion_counter = RecursionCounter::new(recursion_limit);
+        self
+    }
+
+    /// Reset this parser to parse the specified token stream
+    pub fn with_tokens_with_locations(mut self, tokens: Vec<TokenWithLocation>) -> Self {
+        self.tokens = tokens;
+        self.index = 0;
+        self
     }
 
-    pub fn new_without_locations(tokens: Vec<Token>, dialect: &'a dyn Dialect) -> Self {
+    /// Reset this parser state to parse the specified tokens
+    pub fn with_tokens(self, tokens: Vec<Token>) -> Self {
         // Put in dummy locations
         let tokens_with_locations: Vec<TokenWithLocation> = tokens
             .into_iter()
@@ -128,49 +272,84 @@ impl<'a> Parser<'a> {
                 location: Location { line: 0, column: 0 },
             })
             .collect();
-        Parser::new_with_locations(tokens_with_locations, dialect)
+        self.with_tokens_with_locations(tokens_with_locations)
     }
 
-    /// Parse the specified tokens
-    pub fn new_with_locations(tokens: Vec<TokenWithLocation>, dialect: &'a dyn Dialect) -> Self {
-        Parser {
-            tokens,
-            index: 0,
-            dialect,
-        }
+    /// Tokenize the sql string and set this [`Parser`]'s state to
+    /// parse the resulting tokens
+    ///
+    /// Returns an error if there was an error tokenizing the SQL string.
+    ///
+    /// See the example on [`Parser::new()`]
+    pub fn try_with_sql(self, sql: &str) -> Result<Self, ParserError> {
+        debug!("Parsing sql '{}'...", sql);
+        let mut tokenizer = Tokenizer::new(self.dialect, sql);
+        let tokens = tokenizer.tokenize()?;
+        Ok(self.with_tokens(tokens))
     }
 
-    /// Parse a SQL statement and produce an Abstract Syntax Tree (AST)
-    pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
-        let mut tokenizer = Tokenizer::new(dialect, sql);
-        let tokens = tokenizer.tokenize()?;
-        let mut parser = Parser::new(tokens, dialect);
+    /// Parse potentially multiple statements
+    ///
+    /// Example
+    /// ```
+    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
+    /// # fn main() -> Result<(), ParserError> {
+    /// let dialect = GenericDialect{};
+    /// let statements = Parser::new(&dialect)
+    ///   // Parse a SQL string with 2 separate statements
+    ///   .try_with_sql("SELECT * FROM foo; SELECT * FROM bar;")?
+    ///   .parse_statements()?;
+    /// assert_eq!(statements.len(), 2);
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn parse_statements(&mut self) -> Result<Vec<Statement>, ParserError> {
         let mut stmts = Vec::new();
         let mut expecting_statement_delimiter = false;
-        debug!("Parsing sql '{}'...", sql);
         loop {
             // ignore empty statements (between successive statement delimiters)
-            while parser.consume_token(&Token::SemiColon) {
+            while self.consume_token(&Token::SemiColon) {
                 expecting_statement_delimiter = false;
             }
 
-            if parser.peek_token() == Token::EOF {
+            if self.peek_token() == Token::EOF {
                 break;
             }
             if expecting_statement_delimiter {
-                return parser.expected("end of statement", parser.peek_token());
+                return self.expected("end of statement", self.peek_token());
            }
 
-            let statement = parser.parse_statement()?;
+            let statement = self.parse_statement()?;
             stmts.push(statement);
             expecting_statement_delimiter = true;
         }
         Ok(stmts)
     }
 
+    /// Convenience method to parse a string with one or more SQL
+    /// statements into an Abstract Syntax Tree (AST).
+    ///
+    /// Example
+    /// ```
+    /// # use sqlparser::{parser::{Parser, ParserError}, dialect::GenericDialect};
+    /// # fn main() -> Result<(), ParserError> {
+    /// let dialect = GenericDialect{};
+    /// let statements = Parser::parse_sql(
+    ///   &dialect, "SELECT * FROM foo"
+    /// )?;
+    /// assert_eq!(statements.len(), 1);
+    /// # Ok(())
+    /// # }
+    /// ```
+    pub fn parse_sql(dialect: &dyn Dialect, sql: &str) -> Result<Vec<Statement>, ParserError> {
+        Parser::new(dialect).try_with_sql(sql)?.parse_statements()
+    }
+
     /// Parse a single top-level statement (such as SELECT, INSERT, CREATE, etc.),
     /// stopping before the statement separator, if any.
     pub fn parse_statement(&mut self) -> Result<Statement, ParserError> {
+        let _guard = self.recursion_counter.try_decrease()?;
+
         // allow the dialect to override statement parsing
         if let Some(statement) = self.dialect.parse_statement(self) {
             return statement;
@@ -364,6 +543,7 @@ impl<'a> Parser<'a> {
 
     /// Parse a new expression
    pub fn parse_expr(&mut self) -> Result<Expr, ParserError> {
+        let _guard = self.recursion_counter.try_decrease()?;
        self.parse_subexpr(0)
    }
 
@@ -4512,6 +4692,7 @@ impl<'a> Parser<'a> {
     /// by `ORDER BY`. Unlike some other parse_... methods, this one doesn't
     /// expect the initial keyword to be already consumed
     pub fn parse_query(&mut self) -> Result<Query, ParserError> {
+        let _guard = self.recursion_counter.try_decrease()?;
         let with = if self.parse_keyword(Keyword::WITH) {
             Some(With {
                 recursive: self.parse_keyword(Keyword::RECURSIVE),
````
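
The `Rc<AtomicUsize>` plus `Drop` guard used above is a standard RAII countdown: each recursive entry point (`parse_statement`, `parse_expr`, `parse_query`) borrows one unit of depth, and the guard returns it automatically when the call unwinds, so early returns and `?` propagation cannot leak depth. The following self-contained sketch shows the same pattern outside the crate; the names (`DepthCounter`, `Guard`, `descend`) are illustrative and not part of the sqlparser API:

```rust
use std::rc::Rc;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Shared countdown of the remaining recursion depth.
struct DepthCounter {
    remaining: Rc<AtomicUsize>,
}

/// Returns one unit of depth to the counter when dropped.
struct Guard {
    remaining: Rc<AtomicUsize>,
}

impl Drop for Guard {
    fn drop(&mut self) {
        self.remaining.fetch_add(1, Ordering::SeqCst);
    }
}

impl DepthCounter {
    fn new(limit: usize) -> Self {
        Self {
            remaining: Rc::new(AtomicUsize::new(limit)),
        }
    }

    /// Take one unit of depth, or fail once the budget is exhausted
    /// (mirrors `RecursionCounter::try_decrease` in this commit).
    fn enter(&self) -> Result<Guard, &'static str> {
        let old = self.remaining.fetch_sub(1, Ordering::SeqCst);
        if old == 0 {
            return Err("recursion limit exceeded");
        }
        Ok(Guard {
            remaining: Rc::clone(&self.remaining),
        })
    }
}

/// A stand-in for a recursive parse function.
fn descend(counter: &DepthCounter, depth: usize) -> Result<(), &'static str> {
    // The guard lives until this call returns (normally or via `?`),
    // at which point the borrowed depth is handed back.
    let _guard = counter.enter()?;
    if depth == 0 {
        Ok(())
    } else {
        descend(counter, depth - 1)
    }
}

fn main() {
    let counter = DepthCounter::new(5);
    assert!(descend(&counter, 3).is_ok());
    // The depth was fully restored, so a second descent also succeeds.
    assert!(descend(&counter, 3).is_ok());
    // Exceeding the budget fails instead of blowing the stack.
    assert!(descend(&counter, 10).is_err());
}
```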

src/test_utils.rs: 2 additions, 4 deletions

````diff
@@ -29,7 +29,6 @@ use core::fmt::Debug;
 use crate::ast::*;
 use crate::dialect::*;
 use crate::parser::{Parser, ParserError};
-use crate::tokenizer::Tokenizer;
 
 /// Tests use the methods on this struct to invoke the parser on one or
 /// multiple dialects.
@@ -65,9 +64,8 @@ impl TestedDialects {
         F: Fn(&mut Parser) -> T,
     {
         self.one_of_identical_results(|dialect| {
-            let mut tokenizer = Tokenizer::new(dialect, sql);
-            let tokens = tokenizer.tokenize().unwrap();
-            f(&mut Parser::new(tokens, dialect))
+            let mut parser = Parser::new(dialect).try_with_sql(sql).unwrap();
+            f(&mut parser)
         })
     }
 
````
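
Note that this is a breaking change for callers of the old constructors: `Parser::new` no longer takes a token vector, and `new_without_locations` / `new_with_locations` are replaced by the `with_tokens` / `with_tokens_with_locations` builders. A hedged sketch of how downstream code (not part of this commit) might migrate, mirroring the test helper above:

```rust
use sqlparser::dialect::GenericDialect;
use sqlparser::parser::{Parser, ParserError};

fn main() -> Result<(), ParserError> {
    let dialect = GenericDialect {};

    // Before this commit, callers tokenized by hand and passed tokens in:
    //     let tokens = Tokenizer::new(&dialect, sql).tokenize()?;
    //     let mut parser = Parser::new(tokens, &dialect);
    //
    // After this commit, the parser drives tokenization itself:
    let mut parser = Parser::new(&dialect).try_with_sql("SELECT 1")?;
    let statement = parser.parse_statement()?;
    println!("{}", statement);
    Ok(())
}
```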
