Skip to content

chore(completions): add tree sitter query for table aliases #381

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Apr 30, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions crates/pgt_completions/src/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,8 @@ pub(crate) struct CompletionContext<'a> {
pub wrapping_statement_range: Option<tree_sitter::Range>,

pub mentioned_relations: HashMap<Option<String>, HashSet<String>>,

pub mentioned_table_aliases: HashMap<String, String>,
}

impl<'a> CompletionContext<'a> {
Expand All @@ -131,6 +133,7 @@ impl<'a> CompletionContext<'a> {
wrapping_statement_range: None,
is_invocation: false,
mentioned_relations: HashMap::new(),
mentioned_table_aliases: HashMap::new(),
};

ctx.gather_tree_context();
Expand All @@ -146,6 +149,7 @@ impl<'a> CompletionContext<'a> {
let mut executor = TreeSitterQueriesExecutor::new(self.tree.root_node(), sql);

executor.add_query_results::<queries::RelationMatch>();
executor.add_query_results::<queries::TableAliasMatch>();

for relation_match in executor.get_iter(stmt_range) {
match relation_match {
Expand All @@ -166,6 +170,13 @@ impl<'a> CompletionContext<'a> {
}
};
}

QueryResult::TableAliases(table_alias_match) => {
self.mentioned_table_aliases.insert(
table_alias_match.get_alias(sql),
table_alias_match.get_table(sql),
);
}
};
}
}
Expand Down
72 changes: 71 additions & 1 deletion crates/pgt_treesitter_queries/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,77 @@ impl<'a> Iterator for QueryResultIter<'a> {
#[cfg(test)]
mod tests {

use crate::{TreeSitterQueriesExecutor, queries::RelationMatch};
use crate::{
TreeSitterQueriesExecutor,
queries::{RelationMatch, TableAliasMatch},
};

#[test]
fn finds_all_table_aliases() {
let sql = r#"
select
*
from
(
select
something
from
public.cool_table pu
join private.cool_tableau pr on pu.id = pr.id
where
x = '123'
union
select
something_else
from
another_table puat
inner join private.another_tableau prat on puat.id = prat.id
union
select
x,
y
from
public.get_something_cool ()
) as cool
join users u on u.id = cool.something
where
col = 17;
"#;

let mut parser = tree_sitter::Parser::new();
parser.set_language(tree_sitter_sql::language()).unwrap();

let tree = parser.parse(sql, None).unwrap();

let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql);

executor.add_query_results::<TableAliasMatch>();

let results: Vec<&TableAliasMatch> = executor
.get_iter(None)
.filter_map(|q| q.try_into().ok())
.collect();

assert_eq!(results[0].get_schema(sql), Some("public".into()));
assert_eq!(results[0].get_table(sql), "cool_table");
assert_eq!(results[0].get_alias(sql), "pu");

assert_eq!(results[1].get_schema(sql), Some("private".into()));
assert_eq!(results[1].get_table(sql), "cool_tableau");
assert_eq!(results[1].get_alias(sql), "pr");

assert_eq!(results[2].get_schema(sql), None);
assert_eq!(results[2].get_table(sql), "another_table");
assert_eq!(results[2].get_alias(sql), "puat");

assert_eq!(results[3].get_schema(sql), Some("private".into()));
assert_eq!(results[3].get_table(sql), "another_tableau");
assert_eq!(results[3].get_alias(sql), "prat");

assert_eq!(results[4].get_schema(sql), None);
assert_eq!(results[4].get_table(sql), "users");
assert_eq!(results[4].get_alias(sql), "u");
}

#[test]
fn finds_all_relations_and_ignores_functions() {
Expand Down
10 changes: 9 additions & 1 deletion crates/pgt_treesitter_queries/src/queries/mod.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
mod relations;
mod table_aliases;

pub use relations::*;
pub use table_aliases::*;

#[derive(Debug)]
pub enum QueryResult<'a> {
Relation(RelationMatch<'a>),
TableAliases(TableAliasMatch<'a>),
}

impl QueryResult<'_> {
pub fn within_range(&self, range: &tree_sitter::Range) -> bool {
match self {
Self::Relation(rm) => {
QueryResult::Relation(rm) => {
let start = match rm.schema {
Some(s) => s.start_position(),
None => rm.table.start_position(),
Expand All @@ -20,6 +23,11 @@ impl QueryResult<'_> {

start >= range.start_point && end <= range.end_point
}
QueryResult::TableAliases(m) => {
let start = m.table.start_position();
let end = m.alias.end_position();
start >= range.start_point && end <= range.end_point
}
}
}
}
Expand Down
106 changes: 106 additions & 0 deletions crates/pgt_treesitter_queries/src/queries/table_aliases.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
use std::sync::LazyLock;

use crate::{Query, QueryResult};

use super::QueryTryFrom;

static TS_QUERY: LazyLock<tree_sitter::Query> = LazyLock::new(|| {
static QUERY_STR: &str = r#"
(relation
(object_reference
.
(identifier) @schema_or_table
"."?
(identifier)? @table
)
(keyword_as)?
(identifier) @alias
)
"#;
tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query")
});

#[derive(Debug)]
pub struct TableAliasMatch<'a> {
pub(crate) table: tree_sitter::Node<'a>,
pub(crate) alias: tree_sitter::Node<'a>,
pub(crate) schema: Option<tree_sitter::Node<'a>>,
}

impl TableAliasMatch<'_> {
pub fn get_alias(&self, sql: &str) -> String {
self.alias
.utf8_text(sql.as_bytes())
.expect("Failed to get alias from TableAliasMatch")
.to_string()
}

pub fn get_table(&self, sql: &str) -> String {
self.table
.utf8_text(sql.as_bytes())
.expect("Failed to get table from TableAliasMatch")
.to_string()
}

pub fn get_schema(&self, sql: &str) -> Option<String> {
self.schema.as_ref().map(|n| {
n.utf8_text(sql.as_bytes())
.expect("Failed to get table from TableAliasMatch")
.to_string()
})
}
}

impl<'a> TryFrom<&'a QueryResult<'a>> for &'a TableAliasMatch<'a> {
type Error = String;

fn try_from(q: &'a QueryResult<'a>) -> Result<Self, Self::Error> {
match q {
QueryResult::TableAliases(t) => Ok(t),

#[allow(unreachable_patterns)]
_ => Err("Invalid QueryResult type".into()),
}
}
}

impl<'a> QueryTryFrom<'a> for TableAliasMatch<'a> {
type Ref = &'a TableAliasMatch<'a>;
}

impl<'a> Query<'a> for TableAliasMatch<'a> {
fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec<crate::QueryResult<'a>> {
let mut cursor = tree_sitter::QueryCursor::new();

let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes());

let mut to_return = vec![];

for m in matches {
if m.captures.len() == 3 {
let schema = m.captures[0].node;
let table = m.captures[1].node;
let alias = m.captures[2].node;

to_return.push(QueryResult::TableAliases(TableAliasMatch {
table,
alias,
schema: Some(schema),
}));
}

if m.captures.len() == 2 {
let table = m.captures[0].node;
let alias = m.captures[1].node;

to_return.push(QueryResult::TableAliases(TableAliasMatch {
table,
alias,
schema: None,
}));
}
}

to_return
}
}