diff --git a/crates/pgt_completions/src/context.rs b/crates/pgt_completions/src/context.rs index d96d0d53..a17cafa2 100644 --- a/crates/pgt_completions/src/context.rs +++ b/crates/pgt_completions/src/context.rs @@ -8,7 +8,7 @@ use pgt_treesitter_queries::{ use crate::sanitization::SanitizedCompletionParams; -#[derive(Debug, PartialEq, Eq)] +#[derive(Debug, PartialEq, Eq, Hash)] pub enum WrappingClause<'a> { Select, Where, @@ -26,6 +26,12 @@ pub(crate) enum NodeText<'a> { Original(&'a str), } +#[derive(PartialEq, Eq, Hash, Debug)] +pub(crate) struct MentionedColumn { + pub(crate) column: String, + pub(crate) alias: Option, +} + /// We can map a few nodes, such as the "update" node, to actual SQL clauses. /// That gives us a lot of insight for completions. /// Other nodes, such as the "relation" node, gives us less but still @@ -108,8 +114,8 @@ pub(crate) struct CompletionContext<'a> { pub is_in_error_node: bool, pub mentioned_relations: HashMap, HashSet>, - pub mentioned_table_aliases: HashMap, + pub mentioned_columns: HashMap>, HashSet>, } impl<'a> CompletionContext<'a> { @@ -127,6 +133,7 @@ impl<'a> CompletionContext<'a> { is_invocation: false, mentioned_relations: HashMap::new(), mentioned_table_aliases: HashMap::new(), + mentioned_columns: HashMap::new(), is_in_error_node: false, }; @@ -144,6 +151,7 @@ impl<'a> CompletionContext<'a> { executor.add_query_results::(); executor.add_query_results::(); + executor.add_query_results::(); for relation_match in executor.get_iter(stmt_range) { match relation_match { @@ -151,26 +159,38 @@ impl<'a> CompletionContext<'a> { let schema_name = r.get_schema(sql); let table_name = r.get_table(sql); - let current = self.mentioned_relations.get_mut(&schema_name); - - match current { - Some(c) => { - c.insert(table_name); - } - None => { - let mut new = HashSet::new(); - new.insert(table_name); - self.mentioned_relations.insert(schema_name, new); - } - }; + if let Some(c) = self.mentioned_relations.get_mut(&schema_name) { + c.insert(table_name); + } else { + let mut new = HashSet::new(); + new.insert(table_name); + self.mentioned_relations.insert(schema_name, new); + } } - QueryResult::TableAliases(table_alias_match) => { self.mentioned_table_aliases.insert( table_alias_match.get_alias(sql), table_alias_match.get_table(sql), ); } + QueryResult::SelectClauseColumns(c) => { + let mentioned = MentionedColumn { + column: c.get_column(sql), + alias: c.get_alias(sql), + }; + + if let Some(cols) = self + .mentioned_columns + .get_mut(&Some(WrappingClause::Select)) + { + cols.insert(mentioned); + } else { + let mut new = HashSet::new(); + new.insert(mentioned); + self.mentioned_columns + .insert(Some(WrappingClause::Select), new); + } + } }; } } diff --git a/crates/pgt_completions/src/providers/columns.rs b/crates/pgt_completions/src/providers/columns.rs index 331c4416..8109ba83 100644 --- a/crates/pgt_completions/src/providers/columns.rs +++ b/crates/pgt_completions/src/providers/columns.rs @@ -484,4 +484,93 @@ mod tests { ) .await; } + + #[tokio::test] + async fn prefers_not_mentioned_columns() { + let setup = r#" + create schema auth; + + create table public.one ( + id serial primary key, + a text, + b text, + z text + ); + + create table public.two ( + id serial primary key, + c text, + d text, + e text + ); + "#; + + assert_complete_results( + format!( + "select {} from public.one o join public.two on o.id = t.id;", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::Label("a".to_string()), + CompletionAssertion::Label("b".to_string()), + CompletionAssertion::Label("c".to_string()), + CompletionAssertion::Label("d".to_string()), + CompletionAssertion::Label("e".to_string()), + ], + setup, + ) + .await; + + // "a" is already mentioned, so it jumps down + assert_complete_results( + format!( + "select a, {} from public.one o join public.two on o.id = t.id;", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::Label("b".to_string()), + CompletionAssertion::Label("c".to_string()), + CompletionAssertion::Label("d".to_string()), + CompletionAssertion::Label("e".to_string()), + CompletionAssertion::Label("id".to_string()), + CompletionAssertion::Label("z".to_string()), + CompletionAssertion::Label("a".to_string()), + ], + setup, + ) + .await; + + // "id" of table one is mentioned, but table two isn't – + // its priority stays up + assert_complete_results( + format!( + "select o.id, a, b, c, d, e, {} from public.one o join public.two on o.id = t.id;", + CURSOR_POS + ) + .as_str(), + vec![ + CompletionAssertion::LabelAndDesc( + "id".to_string(), + "Table: public.two".to_string(), + ), + CompletionAssertion::Label("z".to_string()), + ], + setup, + ) + .await; + + // "id" is ambiguous, so both "id" columns are lowered in priority + assert_complete_results( + format!( + "select id, a, b, c, d, e, {} from public.one o join public.two on o.id = t.id;", + CURSOR_POS + ) + .as_str(), + vec![CompletionAssertion::Label("z".to_string())], + setup, + ) + .await; + } } diff --git a/crates/pgt_completions/src/providers/triggers.rs b/crates/pgt_completions/src/providers/triggers.rs new file mode 100644 index 00000000..6bc04deb --- /dev/null +++ b/crates/pgt_completions/src/providers/triggers.rs @@ -0,0 +1,169 @@ +use crate::{ + CompletionItemKind, + builder::{CompletionBuilder, PossibleCompletionItem}, + context::CompletionContext, + relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore}, +}; + +use super::helper::get_completion_text_with_schema_or_alias; + +pub fn complete_functions<'a>(ctx: &'a CompletionContext, builder: &mut CompletionBuilder<'a>) { + let available_functions = &ctx.schema_cache.functions; + + for func in available_functions { + let relevance = CompletionRelevanceData::Function(func); + + let item = PossibleCompletionItem { + label: func.name.clone(), + score: CompletionScore::from(relevance.clone()), + filter: CompletionFilter::from(relevance), + description: format!("Schema: {}", func.schema), + kind: CompletionItemKind::Function, + completion_text: get_completion_text_with_schema_or_alias( + ctx, + &func.name, + &func.schema, + ), + }; + + builder.add_item(item); + } +} + +#[cfg(test)] +mod tests { + use crate::{ + CompletionItem, CompletionItemKind, complete, + test_helper::{CURSOR_POS, get_test_deps, get_test_params}, + }; + + #[tokio::test] + async fn completes_fn() { + let setup = r#" + create or replace function cool() + returns trigger + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + "#; + + let query = format!("select coo{}", CURSOR_POS); + + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); + let results = complete(params); + + let CompletionItem { label, .. } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, "cool"); + } + + #[tokio::test] + async fn prefers_fn_if_invocation() { + let setup = r#" + create table coos ( + id serial primary key, + name text + ); + + create or replace function cool() + returns trigger + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + "#; + + let query = format!(r#"select * from coo{}()"#, CURSOR_POS); + + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); + let results = complete(params); + + let CompletionItem { label, kind, .. } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, "cool"); + assert_eq!(kind, CompletionItemKind::Function); + } + + #[tokio::test] + async fn prefers_fn_in_select_clause() { + let setup = r#" + create table coos ( + id serial primary key, + name text + ); + + create or replace function cool() + returns trigger + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + "#; + + let query = format!(r#"select coo{}"#, CURSOR_POS); + + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); + let results = complete(params); + + let CompletionItem { label, kind, .. } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, "cool"); + assert_eq!(kind, CompletionItemKind::Function); + } + + #[tokio::test] + async fn prefers_function_in_from_clause_if_invocation() { + let setup = r#" + create table coos ( + id serial primary key, + name text + ); + + create or replace function cool() + returns trigger + language plpgsql + security invoker + as $$ + begin + raise exception 'dont matter'; + end; + $$; + "#; + + let query = format!(r#"select * from coo{}()"#, CURSOR_POS); + + let (tree, cache) = get_test_deps(setup, query.as_str().into()).await; + let params = get_test_params(&tree, &cache, query.as_str().into()); + let results = complete(params); + + let CompletionItem { label, kind, .. } = results + .into_iter() + .next() + .expect("Should return at least one completion item"); + + assert_eq!(label, "cool"); + assert_eq!(kind, CompletionItemKind::Function); + } +} diff --git a/crates/pgt_completions/src/relevance/scoring.rs b/crates/pgt_completions/src/relevance/scoring.rs index 71c01023..b0b0bf63 100644 --- a/crates/pgt_completions/src/relevance/scoring.rs +++ b/crates/pgt_completions/src/relevance/scoring.rs @@ -32,6 +32,7 @@ impl CompletionScore<'_> { self.check_matching_clause_type(ctx); self.check_matching_wrapping_node(ctx); self.check_relations_in_stmt(ctx); + self.check_columns_in_stmt(ctx); } fn check_matches_query_input(&mut self, ctx: &CompletionContext) { @@ -235,4 +236,40 @@ impl CompletionScore<'_> { self.score += 2; } } + + fn check_columns_in_stmt(&mut self, ctx: &CompletionContext) { + if let CompletionRelevanceData::Column(column) = self.data { + /* + * Columns can be mentioned in one of two ways: + * + * 1) With an alias: `select u.id`. + * If the currently investigated suggestion item is "id" of the "users" table, + * we want to check + * a) whether the name of the column matches. + * b) whether we know which table is aliased by "u" (if we don't, we ignore the alias). + * c) whether the aliased table matches the currently investigated suggestion item's table. + * + * 2) Without an alias: `select id`. + * In that case, we only check whether the mentioned column fits our currently investigated + * suggestion item's name. + * + */ + if ctx + .mentioned_columns + .get(&ctx.wrapping_clause_type) + .is_some_and(|set| { + set.iter().any(|mentioned| match mentioned.alias.as_ref() { + Some(als) => { + let aliased_table = ctx.mentioned_table_aliases.get(als.as_str()); + column.name == mentioned.column + && aliased_table.is_none_or(|t| t == &column.table_name) + } + None => mentioned.column == column.name, + }) + }) + { + self.score -= 10; + } + } + } } diff --git a/crates/pgt_completions/src/test_helper.rs b/crates/pgt_completions/src/test_helper.rs index a6b57c55..937c11af 100644 --- a/crates/pgt_completions/src/test_helper.rs +++ b/crates/pgt_completions/src/test_helper.rs @@ -146,6 +146,7 @@ mod tests { pub(crate) enum CompletionAssertion { Label(String), LabelAndKind(String, CompletionItemKind), + LabelAndDesc(String, String), LabelNotExists(String), KindNotExists(CompletionItemKind), } @@ -186,6 +187,18 @@ impl CompletionAssertion { kind ); } + CompletionAssertion::LabelAndDesc(label, desc) => { + assert_eq!( + &item.label, label, + "Expected label to be {}, but got {}", + label, &item.label + ); + assert_eq!( + &item.description, desc, + "Expected desc to be {}, but got {}", + desc, &item.description + ); + } } } } @@ -202,7 +215,9 @@ pub(crate) async fn assert_complete_results( let (not_existing, existing): (Vec, Vec) = assertions.into_iter().partition(|a| match a { CompletionAssertion::LabelNotExists(_) | CompletionAssertion::KindNotExists(_) => true, - CompletionAssertion::Label(_) | CompletionAssertion::LabelAndKind(_, _) => false, + CompletionAssertion::Label(_) + | CompletionAssertion::LabelAndKind(_, _) + | CompletionAssertion::LabelAndDesc(_, _) => false, }); assert!( diff --git a/crates/pgt_treesitter_queries/src/queries/mod.rs b/crates/pgt_treesitter_queries/src/queries/mod.rs index 4e10ed60..e02d675b 100644 --- a/crates/pgt_treesitter_queries/src/queries/mod.rs +++ b/crates/pgt_treesitter_queries/src/queries/mod.rs @@ -1,13 +1,16 @@ mod relations; +mod select_columns; mod table_aliases; pub use relations::*; +pub use select_columns::*; pub use table_aliases::*; #[derive(Debug)] pub enum QueryResult<'a> { Relation(RelationMatch<'a>), TableAliases(TableAliasMatch<'a>), + SelectClauseColumns(SelectColumnMatch<'a>), } impl QueryResult<'_> { @@ -28,6 +31,16 @@ impl QueryResult<'_> { let end = m.alias.end_position(); start >= range.start_point && end <= range.end_point } + Self::SelectClauseColumns(cm) => { + let start = match cm.alias { + Some(n) => n.start_position(), + None => cm.column.start_position(), + }; + + let end = cm.column.end_position(); + + start >= range.start_point && end <= range.end_point + } } } } diff --git a/crates/pgt_treesitter_queries/src/queries/select_columns.rs b/crates/pgt_treesitter_queries/src/queries/select_columns.rs new file mode 100644 index 00000000..00b6977d --- /dev/null +++ b/crates/pgt_treesitter_queries/src/queries/select_columns.rs @@ -0,0 +1,172 @@ +use std::sync::LazyLock; + +use crate::{Query, QueryResult}; + +use super::QueryTryFrom; + +static TS_QUERY: LazyLock = LazyLock::new(|| { + static QUERY_STR: &str = r#" + (select_expression + (term + (field + (object_reference)? @alias + "."? + (identifier) @column + ) + ) + ","? + ) +"#; + tree_sitter::Query::new(tree_sitter_sql::language(), QUERY_STR).expect("Invalid TS Query") +}); + +#[derive(Debug)] +pub struct SelectColumnMatch<'a> { + pub(crate) alias: Option>, + pub(crate) column: tree_sitter::Node<'a>, +} + +impl SelectColumnMatch<'_> { + pub fn get_alias(&self, sql: &str) -> Option { + let str = self + .alias + .as_ref()? + .utf8_text(sql.as_bytes()) + .expect("Failed to get alias from ColumnMatch"); + + Some(str.to_string()) + } + + pub fn get_column(&self, sql: &str) -> String { + self.column + .utf8_text(sql.as_bytes()) + .expect("Failed to get column from ColumnMatch") + .to_string() + } +} + +impl<'a> TryFrom<&'a QueryResult<'a>> for &'a SelectColumnMatch<'a> { + type Error = String; + + fn try_from(q: &'a QueryResult<'a>) -> Result { + match q { + QueryResult::SelectClauseColumns(c) => Ok(c), + + #[allow(unreachable_patterns)] + _ => Err("Invalid QueryResult type".into()), + } + } +} + +impl<'a> QueryTryFrom<'a> for SelectColumnMatch<'a> { + type Ref = &'a SelectColumnMatch<'a>; +} + +impl<'a> Query<'a> for SelectColumnMatch<'a> { + fn execute(root_node: tree_sitter::Node<'a>, stmt: &'a str) -> Vec> { + let mut cursor = tree_sitter::QueryCursor::new(); + + let matches = cursor.matches(&TS_QUERY, root_node, stmt.as_bytes()); + + let mut to_return = vec![]; + + for m in matches { + if m.captures.len() == 1 { + let capture = m.captures[0].node; + to_return.push(QueryResult::SelectClauseColumns(SelectColumnMatch { + alias: None, + column: capture, + })); + } + + if m.captures.len() == 2 { + let alias = m.captures[0].node; + let column = m.captures[1].node; + + to_return.push(QueryResult::SelectClauseColumns(SelectColumnMatch { + alias: Some(alias), + column, + })); + } + } + + to_return + } +} + +#[cfg(test)] +mod tests { + use crate::TreeSitterQueriesExecutor; + + use super::SelectColumnMatch; + + #[test] + fn finds_all_columns() { + let sql = r#"select aud, id, email from auth.users;"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&SelectColumnMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results[0].get_alias(sql), None); + assert_eq!(results[0].get_column(sql), "aud"); + + assert_eq!(results[1].get_alias(sql), None); + assert_eq!(results[1].get_column(sql), "id"); + + assert_eq!(results[2].get_alias(sql), None); + assert_eq!(results[2].get_column(sql), "email"); + } + + #[test] + fn finds_columns_with_aliases() { + let sql = r#" +select + u.id, + u.email, + cs.user_settings, + cs.client_id +from + auth.users u + join public.client_settings cs + on u.id = cs.user_id; + +"#; + + let mut parser = tree_sitter::Parser::new(); + parser.set_language(tree_sitter_sql::language()).unwrap(); + + let tree = parser.parse(sql, None).unwrap(); + + let mut executor = TreeSitterQueriesExecutor::new(tree.root_node(), sql); + + executor.add_query_results::(); + + let results: Vec<&SelectColumnMatch> = executor + .get_iter(None) + .filter_map(|q| q.try_into().ok()) + .collect(); + + assert_eq!(results[0].get_alias(sql), Some("u".into())); + assert_eq!(results[0].get_column(sql), "id"); + + assert_eq!(results[1].get_alias(sql), Some("u".into())); + assert_eq!(results[1].get_column(sql), "email"); + + assert_eq!(results[2].get_alias(sql), Some("cs".into())); + assert_eq!(results[2].get_column(sql), "user_settings"); + + assert_eq!(results[3].get_alias(sql), Some("cs".into())); + assert_eq!(results[3].get_column(sql), "client_id"); + } +}