1
1
#[ cfg( feature = "derive-visitor" ) ]
2
2
mod test {
3
- use derive_visitor:: { visitor_enter_fn_mut, DriveMut } ;
3
+ use derive_visitor:: { visitor_enter_fn_mut, DriveMut , Drive , Visitor } ;
4
+ use sqlparser:: ast;
4
5
use sqlparser:: ast:: TableFactor :: Table ;
5
- use sqlparser:: ast:: { ObjectName , TableFactor } ;
6
+ use sqlparser:: ast:: { Ident , Join , JoinConstraint , JoinOperator , ObjectName , TableFactor , Value } ;
6
7
use sqlparser:: dialect:: GenericDialect ;
7
8
use sqlparser:: parser:: Parser ;
8
9
@@ -21,4 +22,61 @@ mod test {
21
22
} ) ) ;
22
23
assert_eq ! ( ast[ 0 ] . to_string( ) , "SELECT x FROM a JOIN b" ) ;
23
24
}
25
+
26
+ #[ test]
27
+ fn test_visitor_add_table_to_join ( ) {
28
+ let sql = "select a,b from t1" ;
29
+ // Table t1 has changed and its fields are now split between tables t1 and t2
30
+ let mut ast = Parser :: parse_sql ( & GenericDialect , sql) . unwrap ( ) ;
31
+ ast[ 0 ] . drive_mut ( & mut visitor_enter_fn_mut (
32
+ |table : & mut ast:: TableWithJoins | {
33
+ let has_t1 = std:: iter:: once ( & table. relation )
34
+ . chain ( table. joins . iter ( ) . map ( |j| & j. relation ) )
35
+ . any ( |r| {
36
+ matches ! ( r, TableFactor :: Table { name: ObjectName ( idents) , ..}
37
+ if idents[ 0 ] . value == "t1" )
38
+ } ) ;
39
+ if has_t1 {
40
+ table. joins . push ( Join {
41
+ relation : Table {
42
+ name : ObjectName ( vec ! [ Ident :: from( "t2" ) ] ) ,
43
+ alias : None ,
44
+ args : None ,
45
+ with_hints : vec ! [ ] ,
46
+ } ,
47
+ join_operator : JoinOperator :: Inner ( JoinConstraint :: Using ( vec ! [
48
+ Ident :: from( "t_id" ) ,
49
+ ] ) ) ,
50
+ } ) ;
51
+ }
52
+ } ,
53
+ ) ) ;
54
+ assert_eq ! (
55
+ ast[ 0 ] . to_string( ) ,
56
+ "SELECT a, b FROM t1 JOIN t2 USING(t_id)"
57
+ ) ;
58
+ }
59
+
60
+ #[ test]
61
+ fn test_immutable_visitor_count_parameters ( ) {
62
+ let sql = "select a, b+$x from t1 where c=$y" ;
63
+ // Table t1 has changed and its fields are now split between tables t1 and t2
64
+ let ast = Parser :: parse_sql ( & GenericDialect , sql) . unwrap ( ) ;
65
+
66
+ /// Counts the placeholder in an SQL statement
67
+ #[ derive( Visitor , Default ) ]
68
+ #[ visitor( Value ( enter) ) ]
69
+ struct PlaceholderCounter ( usize ) ;
70
+ impl PlaceholderCounter {
71
+ fn enter_value ( & mut self , value : & Value ) {
72
+ if let Value :: Placeholder ( _) = value {
73
+ self . 0 += 1 ;
74
+ }
75
+ }
76
+ }
77
+
78
+ let mut counter = PlaceholderCounter ( 0 ) ;
79
+ ast[ 0 ] . drive ( & mut counter) ;
80
+ assert_eq ! ( counter. 0 , 2 , "There are 2 placeholders in the query" ) ;
81
+ }
24
82
}
0 commit comments