Skip to content

Commit 0b5c12e

Browse files
committed
add tests
1 parent c33120a commit 0b5c12e

File tree

1 file changed

+60
-2
lines changed

1 file changed

+60
-2
lines changed

tests/sqlparser_visitor.rs

Lines changed: 60 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
#[cfg(feature = "derive-visitor")]
22
mod test {
3-
use derive_visitor::{visitor_enter_fn_mut, DriveMut};
3+
use derive_visitor::{visitor_enter_fn_mut, DriveMut, Drive, Visitor};
4+
use sqlparser::ast;
45
use sqlparser::ast::TableFactor::Table;
5-
use sqlparser::ast::{ObjectName, TableFactor};
6+
use sqlparser::ast::{Ident, Join, JoinConstraint, JoinOperator, ObjectName, TableFactor, Value};
67
use sqlparser::dialect::GenericDialect;
78
use sqlparser::parser::Parser;
89

@@ -21,4 +22,61 @@ mod test {
2122
}));
2223
assert_eq!(ast[0].to_string(), "SELECT x FROM a JOIN b");
2324
}
25+
26+
#[test]
27+
fn test_visitor_add_table_to_join() {
28+
let sql = "select a,b from t1";
29+
// Table t1 has changed and its fields are now split between tables t1 and t2
30+
let mut ast = Parser::parse_sql(&GenericDialect, sql).unwrap();
31+
ast[0].drive_mut(&mut visitor_enter_fn_mut(
32+
|table: &mut ast::TableWithJoins| {
33+
let has_t1 = std::iter::once(&table.relation)
34+
.chain(table.joins.iter().map(|j| &j.relation))
35+
.any(|r| {
36+
matches!(r, TableFactor::Table {name: ObjectName(idents), ..}
37+
if idents[0].value == "t1")
38+
});
39+
if has_t1 {
40+
table.joins.push(Join {
41+
relation: Table {
42+
name: ObjectName(vec![Ident::from("t2")]),
43+
alias: None,
44+
args: None,
45+
with_hints: vec![],
46+
},
47+
join_operator: JoinOperator::Inner(JoinConstraint::Using(vec![
48+
Ident::from("t_id"),
49+
])),
50+
});
51+
}
52+
},
53+
));
54+
assert_eq!(
55+
ast[0].to_string(),
56+
"SELECT a, b FROM t1 JOIN t2 USING(t_id)"
57+
);
58+
}
59+
60+
#[test]
61+
fn test_immutable_visitor_count_parameters() {
62+
let sql = "select a, b+$x from t1 where c=$y";
63+
// Table t1 has changed and its fields are now split between tables t1 and t2
64+
let ast = Parser::parse_sql(&GenericDialect, sql).unwrap();
65+
66+
/// Counts the placeholder in an SQL statement
67+
#[derive(Visitor, Default)]
68+
#[visitor(Value(enter))]
69+
struct PlaceholderCounter(usize);
70+
impl PlaceholderCounter {
71+
fn enter_value(&mut self, value: &Value) {
72+
if let Value::Placeholder(_) = value {
73+
self.0 += 1;
74+
}
75+
}
76+
}
77+
78+
let mut counter = PlaceholderCounter(0);
79+
ast[0].drive(&mut counter);
80+
assert_eq!(counter.0, 2, "There are 2 placeholders in the query");
81+
}
2482
}

0 commit comments

Comments
 (0)