From 75162b273402dcbad2c848c24876a4b35ba3e653 Mon Sep 17 00:00:00 2001 From: Nikhil Benesch Date: Thu, 2 Jul 2020 02:18:19 -0400 Subject: [PATCH] Implement FromSql for tuples up to length 4 This makes it very ergonomic to decode the results of a query like SELECT (1, 'a') where (1, 'a') is returned as an anonymous record type. The big downside to this approach is that only built-in OIDs are supported, as there is no way to know ahead of time what OIDs will be returned, and so we'll only have metadata for the built-in OIDs lying around. --- postgres-types/src/lib.rs | 56 ++++++++++++++++++++++++++ tokio-postgres/tests/test/types/mod.rs | 52 ++++++++++++++++++++++++ 2 files changed, 108 insertions(+) diff --git a/postgres-types/src/lib.rs b/postgres-types/src/lib.rs index e78cedf4a..2b1ec1e00 100644 --- a/postgres-types/src/lib.rs +++ b/postgres-types/src/lib.rs @@ -617,6 +617,62 @@ impl<'a> FromSql<'a> for IpAddr { accepts!(INET); } +macro_rules! impl_from_sql_tuple { + ($n:expr; $($ty_ident:ident),*; $($var_ident:ident),*) => { + impl<'a, $($ty_ident),*> FromSql<'a> for ($($ty_ident,)*) + where + $($ty_ident: FromSql<'a>),* + { + fn from_sql( + _: &Type, + mut raw: &'a [u8], + ) -> Result<($($ty_ident,)*), Box> { + let num_fields = private::read_be_i32(&mut raw)?; + if num_fields as usize != $n { + return Err(format!( + "Postgres record field count does not match Rust tuple length: {} vs {}", + num_fields, + $n, + ).into()); + } + + $( + let oid = private::read_be_i32(&mut raw)? as u32; + let ty = match Type::from_oid(oid) { + None => { + return Err(format!( + "cannot decode OID {} inside of anonymous record", + oid, + ).into()); + } + Some(ty) if !$ty_ident::accepts(&ty) => { + return Err(Box::new(WrongType::new::<$ty_ident>(ty.clone()))); + } + Some(ty) => ty, + }; + let $var_ident = private::read_value(&ty, &mut raw)?; + )* + + Ok(($($var_ident,)*)) + } + + fn accepts(ty: &Type) -> bool { + match ty.kind() { + Kind::Pseudo => *ty == Type::RECORD, + Kind::Composite(fields) => fields.len() == $n, + _ => false, + } + } + } + }; +} + +impl_from_sql_tuple!(0; ; ); +impl_from_sql_tuple!(1; T0; v0); +impl_from_sql_tuple!(2; T0, T1; v0, v1); +impl_from_sql_tuple!(3; T0, T1, T2; v0, v1, v2); +impl_from_sql_tuple!(4; T0, T1, T2, T3; v0, v1, v2, v3); + /// An enum representing the nullability of a Postgres value. pub enum IsNull { /// The value is NULL. diff --git a/tokio-postgres/tests/test/types/mod.rs b/tokio-postgres/tests/test/types/mod.rs index 9f96019fe..54a5f85ce 100644 --- a/tokio-postgres/tests/test/types/mod.rs +++ b/tokio-postgres/tests/test/types/mod.rs @@ -547,6 +547,58 @@ async fn composite() { } } +#[tokio::test] +async fn tuples() { + let client = connect("user=postgres").await; + + let row = client.query_one("SELECT ROW()", &[]).await.unwrap(); + row.get::<_, ()>(0); + + let row = client.query_one("SELECT ROW(1)", &[]).await.unwrap(); + let val: (i32,) = row.get(0); + assert_eq!(val, (1,)); + + let row = client.query_one("SELECT (1, 'a')", &[]).await.unwrap(); + let val: (i32, String) = row.get(0); + assert_eq!(val, (1, "a".into())); + + let row = client.query_one("SELECT (1, (2, 3))", &[]).await.unwrap(); + let val: (i32, (i32, i32)) = row.get(0); + assert_eq!(val, (1, (2, 3))); + + let row = client.query_one("SELECT (1, 2)", &[]).await.unwrap(); + let err = row.try_get::<_, (i32, String)>(0).unwrap_err(); + match err.source() { + Some(e) if e.is::() => {} + _ => panic!("Unexpected error {:?}", err), + }; + + let row = client.query_one("SELECT (1, 2, 3)", &[]).await.unwrap(); + let err = row.try_get::<_, (i32, i32)>(0).unwrap_err(); + assert_eq!( + err.to_string(), + "error deserializing column 0: \ + Postgres record field count does not match Rust tuple length: 3 vs 2" + ); + + client + .batch_execute( + "CREATE TYPE pg_temp.simple AS ( + a int, + b text + )", + ) + .await + .unwrap(); + + let row = client + .query_one("SELECT (1, 'a')::simple", &[]) + .await + .unwrap(); + let val: (i32, String) = row.get(0); + assert_eq!(val, (1, "a".into())); +} + #[tokio::test] async fn enum_() { let client = connect("user=postgres").await;