diff --git a/crates/pgt_lsp/tests/server.rs b/crates/pgt_lsp/tests/server.rs
index 96ff566c..353e80ae 100644
--- a/crates/pgt_lsp/tests/server.rs
+++ b/crates/pgt_lsp/tests/server.rs
@@ -1678,3 +1678,84 @@ ALTER TABLE ONLY "public"."campaign_contact_list"
 
     Ok(())
 }
+
+#[sqlx::test(migrator = "pgt_test_utils::MIGRATIONS")]
+async fn test_plpgsql(test_db: PgPool) -> Result<()> {
+    let factory = ServerFactory::default();
+    let mut fs = MemoryFileSystem::default();
+
+    let mut conf = PartialConfiguration::init();
+    conf.merge_with(PartialConfiguration {
+        db: Some(PartialDatabaseConfiguration {
+            database: Some(
+                test_db
+                    .connect_options()
+                    .get_database()
+                    .unwrap()
+                    .to_string(),
+            ),
+            ..Default::default()
+        }),
+        ..Default::default()
+    });
+    fs.insert(
+        url!("postgrestools.jsonc").to_file_path().unwrap(),
+        serde_json::to_string_pretty(&conf).unwrap(),
+    );
+
+    let (service, client) = factory
+        .create_with_fs(None, DynRef::Owned(Box::new(fs)))
+        .into_inner();
+
+    let (stream, sink) = client.split();
+    let mut server = Server::new(service);
+
+    let (sender, mut receiver) = channel(CHANNEL_BUFFER_SIZE);
+    let reader = tokio::spawn(client_handler(stream, sink, sender));
+
+    server.initialize().await?;
+    server.initialized().await?;
+
+    server.load_configuration().await?;
+
+    let initial_content = r#"
+create function test_organisation_id ()
+    returns setof text
+    language plpgsql
+    security invoker
+    as $$
+    declre
+    v_organisation_id uuid;
+begin
+    return next is(private.organisation_id(), v_organisation_id, 'should return organisation_id of token');
+end
+$$;
+"#;
+
+    server.open_document(initial_content).await?;
+
+    let notification = tokio::time::timeout(Duration::from_secs(5), async {
+        loop {
+            match receiver.next().await {
+                Some(ServerNotification::PublishDiagnostics(msg)) => {
+                    if msg.diagnostics.iter().any(|d| {
+                        d.message
+                            .contains("Invalid statement: syntax error at or near \"declre\"")
+                    }) {
+                        return true;
+                    }
+                }
+                _ => continue,
+            }
+        }
+    })
+    .await
+    .is_ok();
+
+    assert!(notification, "expected a diagnostic for the PL/pgSQL syntax error");
+
+    server.shutdown().await?;
+    reader.abort();
+
+    Ok(())
+}
diff --git a/crates/pgt_query_ext/src/diagnostics.rs b/crates/pgt_query_ext/src/diagnostics.rs
index aa16db81..7e3f0a37 100644
--- a/crates/pgt_query_ext/src/diagnostics.rs
+++ b/crates/pgt_query_ext/src/diagnostics.rs
@@ -15,6 +15,16 @@ pub struct SyntaxDiagnostic {
     pub message: MessageAndDescription,
 }
 
+impl SyntaxDiagnostic {
+    /// Create a new syntax diagnostic with the given message and optional span.
+    pub fn new(message: impl Into<String>, span: Option<TextRange>) -> Self {
+        SyntaxDiagnostic {
+            span,
+            message: MessageAndDescription::from(message.into()),
+        }
+    }
+}
+
 impl From<pg_query::Error> for SyntaxDiagnostic {
     fn from(err: pg_query::Error) -> Self {
         SyntaxDiagnostic {
diff --git a/crates/pgt_query_ext/src/lib.rs b/crates/pgt_query_ext/src/lib.rs
index a087ec60..5882a778 100644
--- a/crates/pgt_query_ext/src/lib.rs
+++ b/crates/pgt_query_ext/src/lib.rs
@@ -25,3 +25,38 @@ pub fn parse(sql: &str) -> Result<NodeEnum> {
             .ok_or_else(|| Error::Parse("Unable to find root node".to_string()))
     })?
 }
+
+/// This function parses a PL/pgSQL function.
+///
+/// It expects the entire `CREATE FUNCTION` statement.
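+///
+/// A minimal, illustrative sketch of the expected input and result (the
+/// function name below is just an example; the block is marked `ignore` so
+/// it is not compiled as a doc-test):
+///
+/// ```ignore
+/// let sql = "
+/// create function example() returns void language plpgsql as $$
+/// begin
+///     return;
+/// end
+/// $$;";
+/// assert!(parse_plpgsql(sql).is_ok());
+/// ```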
+pub fn parse_plpgsql(sql: &str) -> Result<()> {
+    // we swallow the error until we have a proper binding
+    let _ = pg_query::parse_plpgsql(sql)?;
+
+    Ok(())
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_parse_plpgsql_err() {
+        let input = "
+create function test_organisation_id ()
+    returns setof text
+    language plpgsql
+    security invoker
+    as $$
+    -- syntax error here
+    decare
+    v_organisation_id uuid;
+begin
+    select 1;
+end
+$$;
+        ";
+
+        assert!(parse_plpgsql(input).is_err());
+    }
+}
diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs
index f7ace3c2..399f2ec6 100644
--- a/crates/pgt_workspace/src/workspace/server.rs
+++ b/crates/pgt_workspace/src/workspace/server.rs
@@ -53,6 +53,7 @@ mod async_helper;
 mod connection_key;
 mod connection_manager;
 pub(crate) mod document;
+mod function_utils;
 mod migration;
 mod pg_query;
 mod schema_cache_manager;
@@ -528,7 +529,7 @@ impl Workspace for WorkspaceServer {
 
         diagnostics.extend(
             doc.iter(SyncDiagnosticsMapper)
-                .flat_map(|(_id, range, ast, diag)| {
+                .flat_map(|(range, ast, diag)| {
                     let mut errors: Vec<Error> = vec![];
 
                     if let Some(diag) = diag {
@@ -560,9 +561,12 @@ impl Workspace for WorkspaceServer {
                                 },
                             );
 
+                            // adjust the span of the diagnostics to the statement (if it has one)
+                            let span = d.location().span.map(|s| s + range.start());
+
                             SDiagnostic::new(
                                 d.with_file_path(params.path.as_path().display().to_string())
-                                    .with_file_span(range)
+                                    .with_file_span(span.unwrap_or(range))
                                     .with_severity(severity),
                             )
                         })
diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs
index f8ab639d..9d3700df 100644
--- a/crates/pgt_workspace/src/workspace/server/document.rs
+++ b/crates/pgt_workspace/src/workspace/server/document.rs
@@ -243,7 +243,6 @@ impl<'a> StatementMapper<'a> for AsyncDiagnosticsMapper {
 pub struct SyncDiagnosticsMapper;
 impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper {
     type Output = (
-        StatementId,
         TextRange,
         Option<pgt_query_ext::NodeEnum>,
         Option<SyntaxDiagnostic>,
     );
@@ -253,11 +252,18 @@ impl<'a> StatementMapper<'a> for SyncDiagnosticsMapper {
         let ast_result = parser.ast_db.get_or_cache_ast(&id);
 
         let (ast_option, diagnostics) = match &*ast_result {
-            Ok(node) => (Some(node.clone()), None),
+            Ok(node) => {
+                let plpgsql_result = parser.ast_db.get_or_cache_plpgsql_parse(&id);
+                if let Some(Err(diag)) = plpgsql_result {
+                    (Some(node.clone()), Some(diag.clone()))
+                } else {
+                    (Some(node.clone()), None)
+                }
+            }
             Err(diag) => (None, Some(diag.clone())),
         };
 
-        (id.clone(), range, ast_option, diagnostics)
+        (range, ast_option, diagnostics)
     }
 }
@@ -379,4 +385,274 @@ mod tests {
         assert_eq!(stmts.len(), 2);
         assert_eq!(stmts[1].2, "select $1 + $2;");
     }
+
+    #[test]
+    fn test_sync_diagnostics_mapper_plpgsql_syntax_error() {
+        let input = "
+CREATE FUNCTION test_func()
+    RETURNS void
+    LANGUAGE plpgsql
+    AS $$
+BEGIN
+    -- syntax error: missing semicolon and typo
+    DECLAR x integer
+    x := 10;
+END;
+$$;";
+
+        let d = Document::new(input.to_string(), 1);
+        let results = d.iter(SyncDiagnosticsMapper).collect::<Vec<_>>();
+
+        assert_eq!(results.len(), 1);
+        let (_range, ast, diagnostic) = &results[0];
+
+        // Should have parsed the CREATE FUNCTION statement
+        assert!(ast.is_some());
+
+        // Should have a PL/pgSQL syntax error
+        assert!(diagnostic.is_some());
+        assert_eq!(
+            format!("{:?}", diagnostic.as_ref().unwrap().message),
+            "Invalid statement: syntax error at or near \"DECLAR\""
+        );
+    }
+
+    #[test]
+    fn test_sync_diagnostics_mapper_plpgsql_valid() {
+        let input = "
+CREATE FUNCTION valid_func()
+    RETURNS integer
+    LANGUAGE plpgsql
+    AS $$
+DECLARE
+    x integer := 5;
+BEGIN
+    RETURN x * 2;
+END;
+$$;";
+
+        let d = Document::new(input.to_string(), 1);
+        let results = d.iter(SyncDiagnosticsMapper).collect::<Vec<_>>();
+
+        assert_eq!(results.len(), 1);
+        let (_range, ast, diagnostic) = &results[0];
+
+        // Should have parsed the CREATE FUNCTION statement
+        assert!(ast.is_some());
+
+        // Should NOT have any PL/pgSQL syntax errors
+        assert!(diagnostic.is_none());
+    }
+
+    #[test]
+    fn test_sync_diagnostics_mapper_plpgsql_caching() {
+        let input = "
+CREATE FUNCTION cached_func()
+    RETURNS void
+    LANGUAGE plpgsql
+    AS $$
+BEGIN
+    RAISE NOTICE 'Testing cache';
+END;
+$$;";
+
+        let d = Document::new(input.to_string(), 1);
+
+        let results1 = d.iter(SyncDiagnosticsMapper).collect::<Vec<_>>();
+        assert_eq!(results1.len(), 1);
+        assert!(results1[0].1.is_some());
+        assert!(results1[0].2.is_none());
+
+        let results2 = d.iter(SyncDiagnosticsMapper).collect::<Vec<_>>();
+        assert_eq!(results2.len(), 1);
+        assert!(results2[0].1.is_some());
+        assert!(results2[0].2.is_none());
+    }
+
+    #[test]
+    fn test_default_mapper() {
+        let input = "SELECT 1; INSERT INTO users VALUES (1);";
+        let d = Document::new(input.to_string(), 1);
+
+        let results = d.iter(DefaultMapper).collect::<Vec<_>>();
+        assert_eq!(results.len(), 2);
+
+        assert_eq!(results[0].2, "SELECT 1;");
+        assert_eq!(results[1].2, "INSERT INTO users VALUES (1);");
+
+        assert_eq!(results[0].1.start(), 0.into());
+        assert_eq!(results[0].1.end(), 9.into());
+        assert_eq!(results[1].1.start(), 10.into());
+        assert_eq!(results[1].1.end(), 39.into());
+    }
+
+    #[test]
+    fn test_execute_statement_mapper() {
+        let input = "SELECT 1; INVALID SYNTAX HERE;";
+        let d = Document::new(input.to_string(), 1);
+
+        let results = d.iter(ExecuteStatementMapper).collect::<Vec<_>>();
+        assert_eq!(results.len(), 2);
+
+        // First statement should parse successfully
+        assert_eq!(results[0].2, "SELECT 1;");
+        assert!(results[0].3.is_some());
+
+        // Second statement should fail to parse
+        assert_eq!(results[1].2, "INVALID SYNTAX HERE;");
+        assert!(results[1].3.is_none());
+    }
+
+    #[test]
+    fn test_async_diagnostics_mapper() {
+        let input = "
+CREATE FUNCTION test_fn() RETURNS integer AS $$
+BEGIN
+    RETURN 42;
+END;
+$$ LANGUAGE plpgsql;";
+
+        let d = Document::new(input.to_string(), 1);
+        let results = d.iter(AsyncDiagnosticsMapper).collect::<Vec<_>>();
+
+        assert_eq!(results.len(), 1);
+        let (_id, _range, ast, cst, sql_fn_sig) = &results[0];
+
+        // Should have both AST and CST
+        assert!(ast.is_some());
+        assert_eq!(cst.root_node().kind(), "program");
+
+        // Should not have SQL function signature for top-level statement
+        assert!(sql_fn_sig.is_none());
+    }
+
+    #[test]
+    fn test_async_diagnostics_mapper_with_sql_function_body() {
+        let input =
+            "CREATE FUNCTION add(a int, b int) RETURNS int AS 'SELECT $1 + $2;' LANGUAGE sql;";
+        let d = Document::new(input.to_string(), 1);
+
+        let results = d.iter(AsyncDiagnosticsMapper).collect::<Vec<_>>();
+        assert_eq!(results.len(), 2);
+
+        // Check the function body
+        let (_id, _range, ast, _cst, sql_fn_sig) = &results[1];
+        assert_eq!(_id.content(), "SELECT $1 + $2;");
+        assert!(ast.is_some());
+        assert!(sql_fn_sig.is_some());
+
+        let sig = sql_fn_sig.as_ref().unwrap();
+        assert_eq!(sig.name, "add");
+        assert_eq!(sig.args.len(), 2);
+        assert_eq!(sig.args[0].name, Some("a".to_string()));
+        assert_eq!(sig.args[1].name, Some("b".to_string()));
+    }
+
+    #[test]
+    fn test_get_completions_mapper() {
+        let input = "SELECT * FROM users;";
+        let d = Document::new(input.to_string(), 1);
+
+        let results = d.iter(GetCompletionsMapper).collect::<Vec<_>>();
+        assert_eq!(results.len(), 1);
+
+        let (_id, _range, content, tree) = &results[0];
+        assert_eq!(content, "SELECT * FROM users;");
+        assert_eq!(tree.root_node().kind(), "program");
+    }
+
+    #[test]
+    fn test_get_completions_filter() {
+        let input = "SELECT * FROM users; INSERT INTO";
+        let d = Document::new(input.to_string(), 1);
+
+        // Test cursor at end of first statement (terminated with semicolon)
+        let filter1 = GetCompletionsFilter {
+            cursor_position: 20.into(),
+        };
+        let results1 = d
+            .iter_with_filter(DefaultMapper, filter1)
+            .collect::<Vec<_>>();
+        assert_eq!(results1.len(), 0); // No completions after semicolon
+
+        // Test cursor at end of second statement (not terminated)
+        let filter2 = GetCompletionsFilter {
+            cursor_position: 32.into(),
+        };
+        let results2 = d
+            .iter_with_filter(DefaultMapper, filter2)
+            .collect::<Vec<_>>();
+        assert_eq!(results2.len(), 1);
+        assert_eq!(results2[0].2, "INSERT INTO");
+    }
+
+    #[test]
+    fn test_cursor_position_filter() {
+        let input = "SELECT 1; INSERT INTO users VALUES (1);";
+        let d = Document::new(input.to_string(), 1);
+
+        // Cursor in first statement
+        let filter1 = CursorPositionFilter::new(5.into());
+        let results1 = d
+            .iter_with_filter(DefaultMapper, filter1)
+            .collect::<Vec<_>>();
+        assert_eq!(results1.len(), 1);
+        assert_eq!(results1[0].2, "SELECT 1;");
+
+        // Cursor in second statement
+        let filter2 = CursorPositionFilter::new(25.into());
+        let results2 = d
+            .iter_with_filter(DefaultMapper, filter2)
+            .collect::<Vec<_>>();
+        assert_eq!(results2.len(), 1);
+        assert_eq!(results2[0].2, "INSERT INTO users VALUES (1);");
+    }
+
+    #[test]
+    fn test_id_filter() {
+        let input = "SELECT 1; SELECT 2;";
+        let d = Document::new(input.to_string(), 1);
+
+        // Get all statements first to get their IDs
+        let all_results = d.iter(DefaultMapper).collect::<Vec<_>>();
+        assert_eq!(all_results.len(), 2);
+
+        // Filter by first statement ID
+        let filter = IdFilter::new(all_results[0].0.clone());
+        let results = d
+            .iter_with_filter(DefaultMapper, filter)
+            .collect::<Vec<_>>();
+        assert_eq!(results.len(), 1);
+        assert_eq!(results[0].2, "SELECT 1;");
+    }
+
+    #[test]
+    fn test_no_filter() {
+        let input = "SELECT 1; SELECT 2; SELECT 3;";
+        let d = Document::new(input.to_string(), 1);
+
+        let results = d
+            .iter_with_filter(DefaultMapper, NoFilter)
+            .collect::<Vec<_>>();
+        assert_eq!(results.len(), 3);
+    }
+
+    #[test]
+    fn test_find_method() {
+        let input = "SELECT 1; SELECT 2;";
+        let d = Document::new(input.to_string(), 1);
+
+        // Get all statements to get their IDs
+        let all_results = d.iter(DefaultMapper).collect::<Vec<_>>();
+
+        // Find specific statement
+        let result = d.find(all_results[1].0.clone(), DefaultMapper);
+        assert!(result.is_some());
+        assert_eq!(result.unwrap().2, "SELECT 2;");
+
+        // Try to find non-existent statement
+        let fake_id = StatementId::new("SELECT 3;");
+        let result = d.find(fake_id, DefaultMapper);
+        assert!(result.is_none());
+    }
 }
diff --git a/crates/pgt_workspace/src/workspace/server/function_utils.rs b/crates/pgt_workspace/src/workspace/server/function_utils.rs
new file mode 100644
index 00000000..cf02ceb1
--- /dev/null
+++ b/crates/pgt_workspace/src/workspace/server/function_utils.rs
@@ -0,0 +1,57 @@
+/// Helper function to find a specific option value from function options
+pub fn find_option_value(
+    create_fn: &pgt_query_ext::protobuf::CreateFunctionStmt,
+    option_name: &str,
+) -> Option<String> {
+    create_fn
+        .options
+        .iter()
+        .filter_map(|opt_wrapper| opt_wrapper.node.as_ref())
+        .find_map(|opt| {
+            if let pgt_query_ext::NodeEnum::DefElem(def_elem) = opt {
+                if def_elem.defname == option_name {
+                    def_elem
+                        .arg
+                        .iter()
+                        .filter_map(|arg_wrapper| arg_wrapper.node.as_ref())
+                        .find_map(|arg| {
+                            if let pgt_query_ext::NodeEnum::String(s) = arg {
+                                Some(s.sval.clone())
+                            } else if let pgt_query_ext::NodeEnum::List(l) = arg {
+                                l.items.iter().find_map(|item_wrapper| {
+                                    if let Some(pgt_query_ext::NodeEnum::String(s)) =
+                                        item_wrapper.node.as_ref()
+                                    {
+                                        Some(s.sval.clone())
+                                    } else {
+                                        None
+                                    }
+                                })
+                            } else {
+                                None
+                            }
+                        })
+                } else {
+                    None
+                }
+            } else {
+                None
+            }
+        })
+}
+
+pub fn parse_name(nodes: &[pgt_query_ext::protobuf::Node]) -> Option<(Option<String>, String)> {
+    let names = nodes
+        .iter()
+        .map(|n| match &n.node {
+            Some(pgt_query_ext::NodeEnum::String(s)) => Some(s.sval.clone()),
+            _ => None,
+        })
+        .collect::<Vec<_>>();
+
+    match names.as_slice() {
+        [Some(schema), Some(name)] => Some((Some(schema.clone()), name.clone())),
+        [Some(name)] => Some((None, name.clone())),
+        _ => None,
+    }
+}
diff --git a/crates/pgt_workspace/src/workspace/server/pg_query.rs b/crates/pgt_workspace/src/workspace/server/pg_query.rs
index 6f1fa2c1..ba471dfa 100644
--- a/crates/pgt_workspace/src/workspace/server/pg_query.rs
+++ b/crates/pgt_workspace/src/workspace/server/pg_query.rs
@@ -3,19 +3,25 @@ use std::sync::{Arc, Mutex};
 use lru::LruCache;
 use pgt_query_ext::diagnostics::*;
+use pgt_text_size::TextRange;
 
+use super::function_utils::find_option_value;
 use super::statement_identifier::StatementId;
 
 const DEFAULT_CACHE_SIZE: usize = 1000;
 
 pub struct PgQueryStore {
-    db: Mutex<LruCache<StatementId, Arc<Result<pgt_query_ext::NodeEnum, SyntaxDiagnostic>>>>,
+    ast_db: Mutex<LruCache<StatementId, Arc<Result<pgt_query_ext::NodeEnum, SyntaxDiagnostic>>>>,
+    plpgsql_db: Mutex<LruCache<StatementId, Result<(), SyntaxDiagnostic>>>,
 }
 
 impl PgQueryStore {
     pub fn new() -> PgQueryStore {
         PgQueryStore {
-            db: Mutex::new(LruCache::new(
+            ast_db: Mutex::new(LruCache::new(
+                NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap(),
+            )),
+            plpgsql_db: Mutex::new(LruCache::new(
                 NonZeroUsize::new(DEFAULT_CACHE_SIZE).unwrap(),
             )),
         }
@@ -25,7 +31,7 @@ impl PgQueryStore {
         &self,
         statement: &StatementId,
     ) -> Arc<Result<pgt_query_ext::NodeEnum, SyntaxDiagnostic>> {
-        let mut cache = self.db.lock().unwrap();
+        let mut cache = self.ast_db.lock().unwrap();
 
         if let Some(existing) = cache.get(statement) {
             return existing.clone();
@@ -35,4 +41,190 @@ impl PgQueryStore {
         cache.put(statement.clone(), r.clone());
         r
     }
+
+    pub fn get_or_cache_plpgsql_parse(
+        &self,
+        statement: &StatementId,
+    ) -> Option<Result<(), SyntaxDiagnostic>> {
+        let ast = self.get_or_cache_ast(statement);
+
+        let create_fn = match ast.as_ref() {
+            Ok(pgt_query_ext::NodeEnum::CreateFunctionStmt(node)) => node,
+            _ => return None,
+        };
+
+        let language = find_option_value(create_fn, "language")?;
+
+        if language != "plpgsql" {
+            return None;
+        }
+
+        let mut cache = self.plpgsql_db.lock().unwrap();
+
+        if let Some(existing) = cache.get(statement) {
+            return Some(existing.clone());
+        }
+
+        let sql_body = find_option_value(create_fn, "as")?;
+
+        let start = statement.content().find(&sql_body)?;
+        let end = start + sql_body.len();
+
+        let range = TextRange::new(start.try_into().unwrap(), end.try_into().unwrap());
+
+        let r = pgt_query_ext::parse_plpgsql(statement.content())
+            .map_err(|err| SyntaxDiagnostic::new(err.to_string(), Some(range)));
+        cache.put(statement.clone(), r.clone());
+
+        Some(r)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_plpgsql_syntax_error() {
+        let input = "
+create function test_organisation_id ()
+    returns setof text
+    language plpgsql
+    security invoker
+    as $$
+    -- syntax error here
+    delare
+    v_organisation_id uuid;
+begin
+    return next is(private.organisation_id(), v_organisation_id, 'should return organisation_id of token');
+end
+$$;
+";
+
+        let store = PgQueryStore::new();
+
+        let res = store.get_or_cache_plpgsql_parse(&StatementId::new(input));
+
+        assert!(matches!(res, Some(Err(_))));
+    }
+
+    #[test]
+    fn test_plpgsql_valid() {
+        let input = "
+CREATE FUNCTION test_function()
+    RETURNS integer
+    LANGUAGE plpgsql
+    AS $$
+DECLARE
+    counter integer := 0;
+BEGIN
+    counter := counter + 1;
+    RETURN counter;
+END;
+$$;
+";
+
+        let store = PgQueryStore::new();
+
+        let res = store.get_or_cache_plpgsql_parse(&StatementId::new(input));
+
+        assert!(matches!(res, Some(Ok(_))));
+    }
+
+    #[test]
+    fn test_non_plpgsql_function() {
+        let input = "
+CREATE FUNCTION add_numbers(a integer, b integer)
+    RETURNS integer
+    LANGUAGE sql
+    AS $$
+    SELECT a + b;
+    $$;
+";
+
+        let store = PgQueryStore::new();
+
+        let res = store.get_or_cache_plpgsql_parse(&StatementId::new(input));
+
+        assert!(res.is_none());
+    }
+
+    #[test]
+    fn test_non_function_statement() {
+        let input = "SELECT * FROM users WHERE id = 1;";
+
+        let store = PgQueryStore::new();
+
+        let res = store.get_or_cache_plpgsql_parse(&StatementId::new(input));
+
+        assert!(res.is_none());
+    }
+
+    #[test]
+    fn test_cache_behavior() {
+        let input = "
+CREATE FUNCTION cached_function()
+    RETURNS void
+    LANGUAGE plpgsql
+    AS $$
+BEGIN
+    RAISE NOTICE 'Hello from cache test';
+END;
+$$;
+";
+
+        let store = PgQueryStore::new();
+        let statement_id = StatementId::new(input);
+
+        // First call should parse
+        let res1 = store.get_or_cache_plpgsql_parse(&statement_id);
+        assert!(matches!(res1, Some(Ok(_))));
+
+        // Second call should return cached result
+        let res2 = store.get_or_cache_plpgsql_parse(&statement_id);
+        assert!(matches!(res2, Some(Ok(_))));
+    }
+
+    #[test]
+    fn test_plpgsql_with_complex_body() {
+        let input = "
+CREATE FUNCTION complex_function(p_id integer)
+    RETURNS TABLE(id integer, name text, status boolean)
+    LANGUAGE plpgsql
+    AS $$
+DECLARE
+    v_count integer;
+    v_status boolean := true;
+BEGIN
+    SELECT COUNT(*) INTO v_count FROM users WHERE user_id = p_id;
+
+    IF v_count > 0 THEN
+        RETURN QUERY
+        SELECT u.id, u.name, v_status
+        FROM users u
+        WHERE u.user_id = p_id;
+    ELSE
+        RAISE EXCEPTION 'User not found';
+    END IF;
+END;
+$$;
+";
+
+        let store = PgQueryStore::new();
+
+        let res = store.get_or_cache_plpgsql_parse(&StatementId::new(input));
+
+        assert!(matches!(res, Some(Ok(_))));
+    }
+
+    #[test]
+    fn test_invalid_ast() {
+        let input = "CREATE FUNCTION invalid syntax here";
+
+        let store = PgQueryStore::new();
+
+        let res = store.get_or_cache_plpgsql_parse(&StatementId::new(input));
+
+        assert!(res.is_none());
+    }
 }
diff --git a/crates/pgt_workspace/src/workspace/server/sql_function.rs b/crates/pgt_workspace/src/workspace/server/sql_function.rs
index bc2c6c3b..6161dda7 100644
--- a/crates/pgt_workspace/src/workspace/server/sql_function.rs
+++ b/crates/pgt_workspace/src/workspace/server/sql_function.rs
@@ -1,5 +1,7 @@
 use pgt_text_size::TextRange;
 
+use super::function_utils::{find_option_value, parse_name};
+
 #[derive(Debug, Clone)]
 pub struct ArgType {
     pub schema: Option<String>,
@@ -106,64 +108,6 @@ pub fn get_sql_fn_body(ast: &pgt_query_ext::NodeEnum, content: &str) -> Option<
     }
 }
 
-/// Helper function to find a specific option value from function options
-fn find_option_value(
-    create_fn: &pgt_query_ext::protobuf::CreateFunctionStmt,
-    option_name: &str,
-) -> Option<String> {
-    create_fn
-        .options
-        .iter()
-        .filter_map(|opt_wrapper| opt_wrapper.node.as_ref())
-        .find_map(|opt| {
-            if let pgt_query_ext::NodeEnum::DefElem(def_elem) = opt {
-                if def_elem.defname == option_name {
-                    def_elem
-                        .arg
-                        .iter()
-                        .filter_map(|arg_wrapper| arg_wrapper.node.as_ref())
-                        .find_map(|arg| {
-                            if let pgt_query_ext::NodeEnum::String(s) = arg {
-                                Some(s.sval.clone())
-                            } else if let pgt_query_ext::NodeEnum::List(l) = arg {
-                                l.items.iter().find_map(|item_wrapper| {
-                                    if let Some(pgt_query_ext::NodeEnum::String(s)) =
-                                        item_wrapper.node.as_ref()
-                                    {
-                                        Some(s.sval.clone())
-                                    } else {
-                                        None
-                                    }
-                                })
-                            } else {
-                                None
-                            }
-                        })
-                } else {
-                    None
-                }
-            } else {
-                None
-            }
-        })
-}
-
-fn parse_name(nodes: &[pgt_query_ext::protobuf::Node]) -> Option<(Option<String>, String)> {
-    let names = nodes
-        .iter()
-        .map(|n| match &n.node {
-            Some(pgt_query_ext::NodeEnum::String(s)) => Some(s.sval.clone()),
-            _ => None,
-        })
-        .collect::<Vec<_>>();
-
-    match names.as_slice() {
-        [Some(schema), Some(name)] => Some((Some(schema.clone()), name.clone())),
-        [Some(name)] => Some((None, name.clone())),
-        _ => None,
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
diff --git a/test.sql b/test.sql
index 88b7310d..97c8e639 100644
--- a/test.sql
+++ b/test.sql
@@ -9,3 +9,18 @@ from
     unknown_users;
 
 sel 1;
+
+
+
+create function test_organisation_id ()
+    returns setof text
+    language plpgsql
+    security invoker
+    as $$
+    declre
+    v_organisation_id uuid;
+begin
+    return next is(private.organisation_id(), v_organisation_id, 'should return organisation_id of token');
+end
+$$;
+