1- use std:: iter:: Peekable ;
2-
31use pgt_text_size:: { TextRange , TextSize } ;
2+ use std:: iter:: Peekable ;
43
54pub ( crate ) struct TokenNavigator {
65 tokens : Peekable < std:: vec:: IntoIter < WordWithIndex > > ,
@@ -101,73 +100,139 @@ impl WordWithIndex {
101100 }
102101}
103102
104- /// Note: A policy name within quotation marks will be considered a single word.
105- pub ( crate ) fn sql_to_words ( sql : & str ) -> Result < Vec < WordWithIndex > , String > {
106- let mut words = vec ! [ ] ;
107-
108- let mut start_of_word: Option < usize > = None ;
109- let mut current_word = String :: new ( ) ;
110- let mut in_quotation_marks = false ;
111-
112- for ( current_position, current_char) in sql. char_indices ( ) {
113- if ( current_char. is_ascii_whitespace ( ) || current_char == ';' )
114- && !current_word. is_empty ( )
115- && start_of_word. is_some ( )
116- && !in_quotation_marks
117- {
118- words. push ( WordWithIndex {
119- word : current_word,
120- start : start_of_word. unwrap ( ) ,
121- end : current_position,
122- } ) ;
123-
124- current_word = String :: new ( ) ;
125- start_of_word = None ;
126- } else if ( current_char. is_ascii_whitespace ( ) || current_char == ';' )
127- && current_word. is_empty ( )
128- {
129- // do nothing
130- } else if current_char == '"' && start_of_word. is_none ( ) {
131- in_quotation_marks = true ;
132- current_word. push ( current_char) ;
133- start_of_word = Some ( current_position) ;
134- } else if current_char == '"' && start_of_word. is_some ( ) {
135- current_word. push ( current_char) ;
136- in_quotation_marks = false ;
137- } else if start_of_word. is_some ( ) {
138- current_word. push ( current_char)
103+ pub ( crate ) struct SubStatementParser {
104+ start_of_word : Option < usize > ,
105+ current_word : String ,
106+ in_quotation_marks : bool ,
107+ is_fn_call : bool ,
108+ words : Vec < WordWithIndex > ,
109+ }
110+
111+ impl SubStatementParser {
/// Splits `sql` into words together with their positions.
///
/// A parser with empty state is created, `collect_words` is run over `sql`, and the
/// collected words are returned.
///
/// # Errors
///
/// Returns `Err("String was not closed properly.")` when the parser is still inside
/// quotation marks after the whole input has been consumed.
112+ pub ( crate ) fn parse ( sql : & str ) -> Result < Vec < WordWithIndex > , String > {
113+ let mut parser = SubStatementParser {
114+ start_of_word : None ,
115+ current_word : String :: new ( ) ,
116+ in_quotation_marks : false ,
117+ is_fn_call : false ,
118+ words : vec ! [ ] ,
119+ } ;
120+
121+ parser. collect_words ( sql) ;
122+
// A quote that was opened but never closed leaves this flag set.
123+ if parser. in_quotation_marks {
124+ Err ( "String was not closed properly." . into ( ) )
139125 } else {
140- start_of_word = Some ( current_position) ;
141- current_word. push ( current_char) ;
126+ Ok ( parser. words )
142127 }
143128 }
144129
145- if let Some ( start_of_word) = start_of_word {
146- if !current_word. is_empty ( ) {
147- words. push ( WordWithIndex {
148- word : current_word,
149- start : start_of_word,
150- end : sql. len ( ) ,
151- } ) ;
130+ pub fn collect_words ( & mut self , sql : & str ) {
131+ for ( pos, c) in sql. char_indices ( ) {
132+ match c {
133+ '"' => {
134+ if !self . has_started_word ( ) {
135+ self . in_quotation_marks = true ;
136+ self . add_char ( c) ;
137+ self . start_word ( pos) ;
138+ } else {
139+ self . in_quotation_marks = false ;
140+ self . add_char ( c) ;
141+ }
142+ }
143+
144+ '(' => {
145+ if !self . has_started_word ( ) {
146+ self . push_char_as_word ( c, pos) ;
147+ } else {
148+ self . add_char ( c) ;
149+ self . is_fn_call = true ;
150+ }
151+ }
152+
153+ ')' => {
154+ if self . is_fn_call {
155+ self . add_char ( c) ;
156+ self . is_fn_call = false ;
157+ } else {
158+ if self . has_started_word ( ) {
159+ self . push_word ( pos) ;
160+ }
161+ self . push_char_as_word ( c, pos) ;
162+ }
163+ }
164+
165+ _ => {
166+ if c. is_ascii_whitespace ( ) || c == ';' {
167+ if self . in_quotation_marks {
168+ self . add_char ( c) ;
169+ } else if !self . is_empty ( ) && self . has_started_word ( ) {
170+ self . push_word ( pos) ;
171+ }
172+ } else if self . has_started_word ( ) {
173+ self . add_char ( c) ;
174+ } else {
175+ self . start_word ( pos) ;
176+ self . add_char ( c)
177+ }
178+ }
179+ }
180+ }
181+
182+ if self . has_started_word ( ) && !self . is_empty ( ) {
183+ self . push_word ( sql. len ( ) )
152184 }
153185 }
154186
155- if in_quotation_marks {
156- Err ( "String was not closed properly." . into ( ) )
157- } else {
158- Ok ( words)
187+ fn is_empty ( & self ) -> bool {
188+ self . current_word . is_empty ( )
189+ }
190+
191+ fn add_char ( & mut self , c : char ) {
192+ self . current_word . push ( c)
193+ }
194+
195+ fn start_word ( & mut self , pos : usize ) {
196+ self . start_of_word = Some ( pos) ;
197+ }
198+
199+ fn has_started_word ( & self ) -> bool {
200+ self . start_of_word . is_some ( )
201+ }
202+
203+ fn push_char_as_word ( & mut self , c : char , pos : usize ) {
204+ self . words . push ( WordWithIndex {
205+ word : String :: from ( c) ,
206+ start : pos,
207+ end : pos + 1 ,
208+ } ) ;
209+ }
210+
211+ fn push_word ( & mut self , current_position : usize ) {
212+ self . words . push ( WordWithIndex {
213+ word : self . current_word . clone ( ) ,
214+ start : self . start_of_word . unwrap ( ) ,
215+ end : current_position,
216+ } ) ;
217+ self . current_word = String :: new ( ) ;
218+ self . start_of_word = None ;
159219 }
160220}
161221
222+ /// Note: A policy name within quotation marks will be considered a single word.
223+ pub ( crate ) fn sql_to_words ( sql : & str ) -> Result < Vec < WordWithIndex > , String > {
224+ SubStatementParser :: parse ( sql)
225+ }
226+
162227#[ cfg( test) ]
163228mod tests {
164- use crate :: context:: base_parser:: { WordWithIndex , sql_to_words} ;
229+ use crate :: context:: base_parser:: { SubStatementParser , WordWithIndex , sql_to_words} ;
165230
166231 #[ test]
167232 fn determines_positions_correctly ( ) {
168- let query = "\n create policy \" my cool pol\" \n \t on auth.users\n \t as permissive\n \t for select\n \t \t to public\n \t \t using (true );" . to_string ( ) ;
233+ let query = "\n create policy \" my cool pol\" \n \t on auth.users\n \t as permissive\n \t for select\n \t \t to public\n \t \t using (auth.uid() );" . to_string ( ) ;
169234
170- let words = sql_to_words ( query. as_str ( ) ) . unwrap ( ) ;
235+ let words = SubStatementParser :: parse ( query. as_str ( ) ) . unwrap ( ) ;
171236
172237 assert_eq ! ( words[ 0 ] , to_word( "create" , 1 , 7 ) ) ;
173238 assert_eq ! ( words[ 1 ] , to_word( "policy" , 8 , 14 ) ) ;
@@ -181,7 +246,9 @@ mod tests {
181246 assert_eq ! ( words[ 9 ] , to_word( "to" , 73 , 75 ) ) ;
182247 assert_eq ! ( words[ 10 ] , to_word( "public" , 78 , 84 ) ) ;
183248 assert_eq ! ( words[ 11 ] , to_word( "using" , 87 , 92 ) ) ;
184- assert_eq ! ( words[ 12 ] , to_word( "(true)" , 93 , 99 ) ) ;
249+ assert_eq ! ( words[ 12 ] , to_word( "(" , 93 , 94 ) ) ;
250+ assert_eq ! ( words[ 13 ] , to_word( "auth.uid()" , 94 , 104 ) ) ;
251+ assert_eq ! ( words[ 14 ] , to_word( ")" , 104 , 105 ) ) ;
185252 }
186253
187254 #[ test]
0 commit comments