Skip to content

Commit 98ae38f

Browse files
feat(completions): correctly infer (quoted) schema for tables & functions (#509)
1 parent 2f2d900 commit 98ae38f

File tree

6 files changed

+220
-23
lines changed

6 files changed

+220
-23
lines changed

crates/pgt_completions/src/providers/functions.rs

Lines changed: 72 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ use pgt_treesitter::TreesitterContext;
44
use crate::{
55
CompletionItemKind, CompletionText,
66
builder::{CompletionBuilder, PossibleCompletionItem},
7-
providers::helper::get_range_to_replace,
7+
providers::helper::{get_range_to_replace, node_text_surrounded_by_quotes, only_leading_quote},
88
relevance::{CompletionRelevanceData, filtering::CompletionFilter, scoring::CompletionScore},
99
};
1010

@@ -37,7 +37,7 @@ pub fn complete_functions<'a>(
3737
fn get_completion_text(ctx: &TreesitterContext, func: &Function) -> CompletionText {
3838
let mut text = with_schema_or_alias(ctx, func.name.as_str(), Some(func.schema.as_str()));
3939

40-
let range = get_range_to_replace(ctx);
40+
let mut range = get_range_to_replace(ctx);
4141

4242
if ctx.is_invocation {
4343
CompletionText {
@@ -46,6 +46,11 @@ fn get_completion_text(ctx: &TreesitterContext, func: &Function) -> CompletionTe
4646
is_snippet: false,
4747
}
4848
} else {
49+
if node_text_surrounded_by_quotes(ctx) && !only_leading_quote(ctx) {
50+
text.push('"');
51+
range = range.checked_expand_end(1.into()).unwrap_or(range);
52+
}
53+
4954
text.push('(');
5055

5156
let num_args = func.args.args.len();
@@ -68,6 +73,7 @@ fn get_completion_text(ctx: &TreesitterContext, func: &Function) -> CompletionTe
6873

6974
#[cfg(test)]
7075
mod tests {
76+
use pgt_text_size::TextRange;
7177
use sqlx::{Executor, PgPool};
7278

7379
use crate::{
@@ -294,4 +300,68 @@ mod tests {
294300
)
295301
.await;
296302
}
303+
304+
#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
305+
async fn autocompletes_after_schema_in_quotes(pool: PgPool) {
306+
let setup = r#"
307+
create schema auth;
308+
309+
create or replace function auth.my_cool_foo()
310+
returns trigger
311+
language plpgsql
312+
security invoker
313+
as $$
314+
begin
315+
raise exception 'dont matter';
316+
end;
317+
$$;
318+
"#;
319+
320+
pool.execute(setup).await.unwrap();
321+
322+
assert_complete_results(
323+
format!(
324+
r#"select "auth".{}"#,
325+
QueryWithCursorPosition::cursor_marker()
326+
)
327+
.as_str(),
328+
vec![CompletionAssertion::CompletionTextAndRange(
329+
"my_cool_foo()".into(),
330+
TextRange::new(14.into(), 14.into()),
331+
)],
332+
None,
333+
&pool,
334+
)
335+
.await;
336+
337+
assert_complete_results(
338+
format!(
339+
r#"select "auth"."{}"#,
340+
QueryWithCursorPosition::cursor_marker()
341+
)
342+
.as_str(),
343+
vec![CompletionAssertion::CompletionTextAndRange(
344+
r#"my_cool_foo"()"#.into(),
345+
TextRange::new(15.into(), 15.into()),
346+
)],
347+
None,
348+
&pool,
349+
)
350+
.await;
351+
352+
assert_complete_results(
353+
format!(
354+
r#"select "auth"."{}""#,
355+
QueryWithCursorPosition::cursor_marker()
356+
)
357+
.as_str(),
358+
vec![CompletionAssertion::CompletionTextAndRange(
359+
r#"my_cool_foo"()"#.into(),
360+
TextRange::new(15.into(), 16.into()),
361+
)],
362+
None,
363+
&pool,
364+
)
365+
.await;
366+
}
297367
}

crates/pgt_completions/src/providers/helper.rs

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,12 @@ pub(crate) fn get_range_to_replace(ctx: &TreesitterContext) -> TextRange {
3434
}
3535
}
3636

37+
pub(crate) fn only_leading_quote(ctx: &TreesitterContext) -> bool {
38+
let node_under_cursor_txt = ctx.get_node_under_cursor_content().unwrap_or("".into());
39+
let node_under_cursor_txt = node_under_cursor_txt.as_str();
40+
is_sanitized_token_with_quote(node_under_cursor_txt)
41+
}
42+
3743
pub(crate) fn with_schema_or_alias(
3844
ctx: &TreesitterContext,
3945
item_name: &str,
@@ -42,21 +48,18 @@ pub(crate) fn with_schema_or_alias(
4248
let is_already_prefixed_with_schema_name = ctx.schema_or_alias_name.is_some();
4349

4450
let with_quotes = node_text_surrounded_by_quotes(ctx);
45-
46-
let node_under_cursor_txt = ctx.get_node_under_cursor_content().unwrap_or("".into());
47-
let node_under_cursor_txt = node_under_cursor_txt.as_str();
48-
let is_quote_sanitized = is_sanitized_token_with_quote(node_under_cursor_txt);
51+
let single_leading_quote = only_leading_quote(ctx);
4952

5053
if schema_or_alias_name.is_none_or(|s| s == "public") || is_already_prefixed_with_schema_name {
51-
if is_quote_sanitized {
54+
if single_leading_quote {
5255
format!(r#"{}""#, item_name)
5356
} else {
5457
item_name.to_string()
5558
}
5659
} else {
5760
let schema_or_als = schema_or_alias_name.unwrap();
5861

59-
if is_quote_sanitized {
62+
if single_leading_quote {
6063
format!(r#"{}"."{}""#, schema_or_als.replace('"', ""), item_name)
6164
} else if with_quotes {
6265
format!(r#"{}"."{}"#, schema_or_als.replace('"', ""), item_name)

crates/pgt_completions/src/providers/tables.rs

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ fn get_completion_text(ctx: &TreesitterContext, table: &Table) -> CompletionText
5858
#[cfg(test)]
5959
mod tests {
6060

61+
use pgt_text_size::TextRange;
6162
use sqlx::{Executor, PgPool};
6263

6364
use crate::{
@@ -569,4 +570,90 @@ mod tests {
569570
)
570571
.await;
571572
}
573+
574+
#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
575+
async fn after_quoted_schemas(pool: PgPool) {
576+
let setup = r#"
577+
create schema auth;
578+
579+
create table auth.users (
580+
uid serial primary key,
581+
name text not null,
582+
email text unique not null
583+
);
584+
585+
create table auth.posts (
586+
pid serial primary key,
587+
user_id int not null references auth.users(uid),
588+
title text not null,
589+
content text,
590+
created_at timestamp default now()
591+
);
592+
"#;
593+
594+
pool.execute(setup).await.unwrap();
595+
596+
assert_complete_results(
597+
format!(
598+
r#"select * from "auth".{}"#,
599+
QueryWithCursorPosition::cursor_marker()
600+
)
601+
.as_str(),
602+
vec![
603+
CompletionAssertion::CompletionTextAndRange(
604+
"posts".into(),
605+
TextRange::new(21.into(), 21.into()),
606+
),
607+
CompletionAssertion::CompletionTextAndRange(
608+
"users".into(),
609+
TextRange::new(21.into(), 21.into()),
610+
),
611+
],
612+
None,
613+
&pool,
614+
)
615+
.await;
616+
617+
assert_complete_results(
618+
format!(
619+
r#"select * from "auth"."{}""#,
620+
QueryWithCursorPosition::cursor_marker()
621+
)
622+
.as_str(),
623+
vec![
624+
CompletionAssertion::CompletionTextAndRange(
625+
"posts".into(),
626+
TextRange::new(22.into(), 22.into()),
627+
),
628+
CompletionAssertion::CompletionTextAndRange(
629+
"users".into(),
630+
TextRange::new(22.into(), 22.into()),
631+
),
632+
],
633+
None,
634+
&pool,
635+
)
636+
.await;
637+
638+
assert_complete_results(
639+
format!(
640+
r#"select * from "auth"."{}"#,
641+
QueryWithCursorPosition::cursor_marker()
642+
)
643+
.as_str(),
644+
vec![
645+
CompletionAssertion::CompletionTextAndRange(
646+
r#"posts""#.into(),
647+
TextRange::new(22.into(), 22.into()),
648+
),
649+
CompletionAssertion::CompletionTextAndRange(
650+
r#"users""#.into(),
651+
TextRange::new(22.into(), 22.into()),
652+
),
653+
],
654+
None,
655+
&pool,
656+
)
657+
.await;
658+
}
572659
}

crates/pgt_completions/src/relevance/filtering.rs

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -249,13 +249,13 @@ impl CompletionFilter<'_> {
249249
return Some(());
250250
}
251251

252-
let schema_or_alias = ctx.schema_or_alias_name.as_ref().unwrap();
252+
let schema_or_alias = ctx.schema_or_alias_name.as_ref().unwrap().replace('"', "");
253253

254254
let matches = match self.data {
255-
CompletionRelevanceData::Table(table) => &table.schema == schema_or_alias,
256-
CompletionRelevanceData::Function(f) => &f.schema == schema_or_alias,
255+
CompletionRelevanceData::Table(table) => table.schema == schema_or_alias,
256+
CompletionRelevanceData::Function(f) => f.schema == schema_or_alias,
257257
CompletionRelevanceData::Column(col) => ctx
258-
.get_mentioned_table_for_alias(schema_or_alias)
258+
.get_mentioned_table_for_alias(&schema_or_alias)
259259
.is_some_and(|t| t == &col.table_name),
260260

261261
// we should never allow schema suggestions if there already was one.

crates/pgt_completions/src/relevance/scoring.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -184,9 +184,10 @@ impl CompletionScore<'_> {
184184
}
185185

186186
fn check_matches_schema(&mut self, ctx: &TreesitterContext) {
187+
// TODO
187188
let schema_name = match ctx.schema_or_alias_name.as_ref() {
188189
None => return,
189-
Some(n) => n,
190+
Some(n) => n.replace('"', ""),
190191
};
191192

192193
let data_schema = match self.get_schema_name() {

crates/pgt_completions/src/sanitization.rs

Lines changed: 46 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,11 @@ where
9898
// we want to push spaces until we arrive at the cursor position.
9999
// we'll then add the SANITIZED_TOKEN
100100
if idx == cursor_pos {
101-
sql.push_str(SANITIZED_TOKEN);
101+
if opened_quote && has_uneven_quotes {
102+
sql.push_str(SANITIZED_TOKEN_WITH_QUOTE);
103+
} else {
104+
sql.push_str(SANITIZED_TOKEN);
105+
}
102106
} else {
103107
sql.push(' ');
104108
}
@@ -342,18 +346,50 @@ mod tests {
342346

343347
#[test]
344348
fn should_sanitize_with_opened_quotes() {
345-
// select "email", "| from "auth"."users";
346-
let input = r#"select "email", " from "auth"."users";"#;
347-
let position = TextSize::new(17);
349+
{
350+
// select "email", "| from "auth"."users";
351+
let input = r#"select "email", " from "auth"."users";"#;
352+
let position = TextSize::new(17);
348353

349-
let params = get_test_params(input, position);
354+
let params = get_test_params(input, position);
350355

351-
let sanitized = SanitizedCompletionParams::from(params);
356+
let sanitized = SanitizedCompletionParams::from(params);
352357

353-
assert_eq!(
354-
sanitized.text,
355-
r#"select "email", "REPLACED_TOKEN_WITH_QUOTE" from "auth"."users";"#
356-
);
358+
assert_eq!(
359+
sanitized.text,
360+
r#"select "email", "REPLACED_TOKEN_WITH_QUOTE" from "auth"."users";"#
361+
);
362+
}
363+
364+
{
365+
// select * from "auth"."|; <-- with semi
366+
let input = r#"select * from "auth".";"#;
367+
let position = TextSize::new(22);
368+
369+
let params = get_test_params(input, position);
370+
371+
let sanitized = SanitizedCompletionParams::from(params);
372+
373+
assert_eq!(
374+
sanitized.text,
375+
r#"select * from "auth"."REPLACED_TOKEN_WITH_QUOTE";"#
376+
);
377+
}
378+
379+
{
380+
// select * from "auth"."| <-- without semi
381+
let input = r#"select * from "auth".""#;
382+
let position = TextSize::new(22);
383+
384+
let params = get_test_params(input, position);
385+
386+
let sanitized = SanitizedCompletionParams::from(params);
387+
388+
assert_eq!(
389+
sanitized.text,
390+
r#"select * from "auth"."REPLACED_TOKEN_WITH_QUOTE""#
391+
);
392+
}
357393
}
358394

359395
#[test]

0 commit comments

Comments
 (0)