diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index b16fd21c..db21e498 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -115,6 +115,8 @@ pub(crate) struct CompletionContext<'a> { pub wrapping_statement_range: Option, pub mentioned_relations: HashMap, HashSet>, + + pub mentioned_table_aliases: HashMap, } impl<'a> CompletionContext<'a> { @@ -131,6 +133,7 @@ impl<'a> CompletionContext<'a> { wrapping_statement_range: None, is_invocation: false, mentioned_relations: HashMap::new(), + mentioned_table_aliases: HashMap::new(), }; ctx.gather_tree_context(); @@ -146,6 +149,7 @@ impl<'a> CompletionContext<'a> { let mut executor = TreeSitterQueriesExecutor::new(self.tree.root_node(), sql); executor.add_query_results::(); + executor.add_query_results::(); for relation_match in executor.get_iter(stmt_range) { match relation_match { @@ -166,6 +170,13 @@ impl<'a> CompletionContext<'a> { } }; } + + QueryResult::TableAliases(table_alias_match) => { + self.mentioned_table_aliases.insert( + table_alias_match.get_alias(sql), + table_alias_match.get_table(sql), + ); + } }; } } diff --git a/crates/pgt_treesitter_queries/src/lib.rs b/crates/pgt_treesitter_queries/src/lib.rs index 7d2ba61b..8d1719b0 100644 --- a/crates/pgt_treesitter_queries/src/lib.rs +++ b/crates/pgt_treesitter_queries/src/lib.rs @@ -68,7 +68,77 @@ impl<'a> Iterator for QueryResultIter<'a> { #[cfg(test)] mod tests { - use crate::{TreeSitterQueriesExecutor, queries::RelationMatch}; + use crate::{ + TreeSitterQueriesExecutor, + queries::{RelationMatch, TableAliasMatch}, + }; + + #[test] + fn finds_all_table_aliases() { + let sql = r#" +select + * +from + ( + select + something + from + public.cool_table pu + join private.cool_tableau pr on pu.id = pr.id + where + x = '123' + union + select + something_else + from + another_table puat + inner join private.another_tableau prat on puat.id = prat.id + union + select + x, + y + from + public.get_something_cool () + ) as cool + join users u on u.id = cool.something +where + col = 17; +"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&TableAliasMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results[0].get_schema(sql), Some("public".into())); + assert_eq!(results[0].get_table(sql), "cool_table"); + assert_eq!(results[0].get_alias(sql), "pu"); + + assert_eq!(results[1].get_schema(sql), Some("private".into())); + assert_eq!(results[1].get_table(sql), "cool_tableau"); + assert_eq!(results[1].get_alias(sql), "pr"); + + assert_eq!(results[2].get_schema(sql), None); + assert_eq!(results[2].get_table(sql), "another_table"); + assert_eq!(results[2].get_alias(sql), "puat"); + + assert_eq!(results[3].get_schema(sql), Some("private".into())); + assert_eq!(results[3].get_table(sql), "another_tableau"); + assert_eq!(results[3].get_alias(sql), "prat"); + + assert_eq!(results[4].get_schema(sql), None); + assert_eq!(results[4].get_table(sql), "users"); + assert_eq!(results[4].get_alias(sql), "u"); + } #[test] fn finds_all_relations_and_ignores_functions() { diff --git a/crates/pgt_treesitter_queries/src/queries/mod.rs b/crates/pgt_treesitter_queries/src/queries/mod.rs index 98b55e03..4e10ed60 100644 --- a/crates/pgt_treesitter_queries/src/queries/mod.rs +++ b/crates/pgt_treesitter_queries/src/queries/mod.rs @@ -1,16 +1,19 @@ mod relations; +mod table_aliases; pub use relations::*; +pub use table_aliases::*; #[derive(Debug)] pub enum QueryResult<'a> { Relation(RelationMatch<'a>), + TableAliases(TableAliasMatch<'a>), } impl QueryResult<'_> { pub fn within_range(&self, range: &tree_sitter::Range) -> bool { match self { - Self::Relation(rm) => { + QueryResult::Relation(rm) => { let start = match rm.schema { Some(s) => s.start_position(), None => rm.table.start_position(), @@ -20,6 +23,11 @@ impl QueryResult<'_> { start >= range.start_point && end <= range.end_point } + QueryResult::TableAliases(m) => { + let start = m.table.start_position(); + let end = m.alias.end_position(); + start >= range.start_point && end <= range.end_point + } } } } diff --git a/crates/pgt_treesitter_queries/src/queries/table_aliases.rs b/crates/pgt_treesitter_queries/src/queries/table_aliases.rs new file mode 100644 index 00000000..4297a218 --- /dev/null +++ b/crates/pgt_treesitter_queries/src/queries/table_aliases.rs @@ -0,0 +1,106 @@ +use std::sync::LazyLock; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &str = r#" + (relation + (object_reference + . + (identifier) @schema_or_table + "."? + (identifier)? @table + ) + (keyword_as)? + (identifier) @alias + ) +"#; + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") +}); + +#[derive(Debug)] +pub struct TableAliasMatch<'a> { + pub(crate) table: tree_sitter::Node<'a>, + pub(crate) alias: tree_sitter::Node<'a>, + pub(crate) schema: Option>, +} + +impl TableAliasMatch<'_> { + pub fn get_alias(&self, sql: &str) -> String { + self.alias + .utf8_text(sql.as_bytes()) + .expect("Failed to get alias from TableAliasMatch") + .to_string() + } + + pub fn get_table(&self, sql: &str) -> String { + self.table + .utf8_text(sql.as_bytes()) + .expect("Failed to get table from TableAliasMatch") + .to_string() + } + + pub fn get_schema(&self, sql: &str) -> Option { + self.schema.as_ref().map(|n| { + n.utf8_text(sql.as_bytes()) + .expect("Failed to get table from TableAliasMatch") + .to_string() + }) + } +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a TableAliasMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::TableAliases(t) => Ok(t), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for TableAliasMatch<'a> { + type Ref = &'a TableAliasMatch<'a>; +} + +impl<'a> Query<'a> for TableAliasMatch<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); + + let mut to_return = vec![]; + + for m in matches { + if m.captures.len() == 3 { + let schema = m.captures[0].node; + let table = m.captures[1].node; + let alias = m.captures[2].node; + + to_return.push(QueryResult::TableAliases(TableAliasMatch { + table, + alias, + schema: Some(schema), + })); + } + + if m.captures.len() == 2 { + let table = m.captures[0].node; + let alias = m.captures[1].node; + + to_return.push(QueryResult::TableAliases(TableAliasMatch { + table, + alias, + schema: None, + })); + } + } + + to_return + } +}