diff --git a/postgres-derive-test/src/lib.rs b/postgres-derive-test/src/lib.rs index f0534f32c..a7523ddab 100644 --- a/postgres-derive-test/src/lib.rs +++ b/postgres-derive-test/src/lib.rs @@ -7,6 +7,7 @@ use std::fmt; mod composites; mod domains; mod enums; +mod records; mod transparent; pub fn test_type(conn: &mut Client, sql_type: &str, checks: &[(T, S)]) diff --git a/postgres-derive-test/src/records.rs b/postgres-derive-test/src/records.rs new file mode 100644 index 000000000..6fa1eaed9 --- /dev/null +++ b/postgres-derive-test/src/records.rs @@ -0,0 +1,209 @@ +use postgres::{Client, NoTls}; +use postgres_types::{FromSql, ToSql, WrongType}; +use std::error::Error; + +#[test] +fn basic() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + struct InventoryItem { + name: String, + supplier_id: i32, + price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + + let expected = InventoryItem { + name: "foobar".to_owned(), + supplier_id: 100, + price: Some(15.50), + }; + + let got = conn + .query_one("SELECT ('foobar', 100, 15.50::double precision)", &[]) + .unwrap() + .try_get::<_, InventoryItem>(0) + .unwrap(); + + assert_eq!(got, expected); +} + +#[test] +fn field_count_mismatch() { + #[derive(FromSql, Debug, PartialEq)] + struct InventoryItem { + name: String, + supplier_id: i32, + price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + + let err = conn + .query_one("SELECT ('foobar', 100)", &[]) + .unwrap() + .try_get::<_, InventoryItem>(0) + .unwrap_err(); + err.source().unwrap().is::(); + + let err = conn + .query_one("SELECT ('foobar', 100, 15.50, 'extra')", &[]) + .unwrap() + .try_get::<_, InventoryItem>(0) + .unwrap_err(); + err.source().unwrap().is::(); +} + +#[test] +fn wrong_type() { + #[derive(FromSql, Debug, PartialEq)] + struct InventoryItem { + name: String, + supplier_id: i32, + price: Option, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + + let err = conn + .query_one("SELECT ('foobar', 'not_an_int', 15.50)", &[]) + .unwrap() + .try_get::<_, InventoryItem>(0) + .unwrap_err(); + assert!(err.source().unwrap().is::()); + + let err = conn + .query_one("SELECT (123, 100, 15.50)", &[]) + .unwrap() + .try_get::<_, InventoryItem>(0) + .unwrap_err(); + assert!(err.source().unwrap().is::()); +} + +#[test] +fn nested_structs() { + #[derive(FromSql, Debug, PartialEq)] + struct Address { + street: String, + city: Option, + } + + #[derive(FromSql, Debug, PartialEq)] + struct Person { + name: String, + age: Option, + address: Address, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + + let result: Person = conn + .query_one( + "SELECT ('John', 30, ROW('123 Main St', 'Springfield'))", + &[], + ) + .unwrap() + .get(0); + + let expected = Person { + name: "John".to_owned(), + age: Some(30), + address: Address { + street: "123 Main St".to_owned(), + city: Some("Springfield".to_owned()), + }, + }; + + assert_eq!(result, expected); +} + +#[test] +fn domains() { + #[derive(FromSql, Debug, PartialEq)] + struct SpecialId(i32); + + #[derive(FromSql, Debug, PartialEq)] + struct Person { + name: String, + age: Option, + id: SpecialId, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute("CREATE DOMAIN pg_temp.\"special_id\" AS integer;", &[]) + .unwrap(); + + let result: Person = conn + .query_one("SELECT ('John', 30, 42::special_id)", &[]) + .unwrap() + .get(0); + + let expected = Person { + name: "John".to_owned(), + age: Some(30), + id: SpecialId(42), + }; + + assert_eq!(result, expected); +} + +#[test] +fn enums() { + #[derive(FromSql, Debug, PartialEq)] + enum Employment { + Salaried, + Hourly, + Unemployed, + } + + #[derive(FromSql, Debug, PartialEq)] + struct Person { + name: String, + age: Option, + employment: Employment, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + conn.execute( + "CREATE TYPE pg_temp.employment AS ENUM ('Salaried', 'Hourly', 'Unemployed')", + &[], + ) + .unwrap(); + + let result: Person = conn + .query_one("SELECT ('John', 30, 'Hourly'::employment)", &[]) + .unwrap() + .get(0); + + let expected = Person { + name: "John".to_owned(), + age: Some(30), + employment: Employment::Hourly, + }; + + assert_eq!(result, expected); +} + +#[test] +fn generics() { + #[derive(FromSql, ToSql, Debug, PartialEq)] + struct GenericItem { + first: T, + second: U, + } + + let mut conn = Client::connect("user=postgres host=localhost port=5433", NoTls).unwrap(); + + let expected = GenericItem { + first: "test".to_owned(), + second: 42, + }; + + let got = conn + .query_one("SELECT ('test', 42)", &[]) + .unwrap() + .try_get::<_, GenericItem>(0) + .unwrap(); + + assert_eq!(got, expected); +} diff --git a/postgres-derive/src/accepts.rs b/postgres-derive/src/accepts.rs index a68538dcc..a91e03670 100644 --- a/postgres-derive/src/accepts.rs +++ b/postgres-derive/src/accepts.rs @@ -41,6 +41,9 @@ pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> Toke } } else { quote! { + if *type_ == postgres_types::Type::UNKNOWN { + return true; + } if type_.name() != #name { return false; } @@ -66,20 +69,47 @@ pub fn enum_body(name: &str, variants: &[Variant], allow_mismatch: bool) -> Toke } } -pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream { +pub fn composite_body_from_sql(name: &str, fields: &[Field]) -> TokenStream { let num_fields = fields.len(); - let trait_ = Ident::new(trait_, Span::call_site()); + let trait_ = Ident::new("FromSql", Span::call_site()); let traits = iter::repeat(&trait_); let field_names = fields.iter().map(|f| &f.name); let field_types = fields.iter().map(|f| &f.type_); quote! { - if type_.name() != #name { - return false; + match *type_.kind() { + ::postgres_types::Kind::Composite(ref fields) if type_.name() == #name => { + if fields.len() != #num_fields { + return false; + } + + fields.iter().all(|f| { + match f.name() { + #( + #field_names => { + <#field_types as ::postgres_types::#traits>::accepts(f.type_()) + } + )* + _ => false, + } + }) + }, + ::postgres_types::Kind::Pseudo => true, + _ => false, } + } +} +pub fn composite_body_to_sql(name: &str, fields: &[Field]) -> TokenStream { + let num_fields = fields.len(); + let trait_ = Ident::new("ToSql", Span::call_site()); + let traits = iter::repeat(&trait_); + let field_names = fields.iter().map(|f| &f.name); + let field_types = fields.iter().map(|f| &f.type_); + + quote! { match *type_.kind() { - ::postgres_types::Kind::Composite(ref fields) => { + ::postgres_types::Kind::Composite(ref fields) if type_.name() == #name => { if fields.len() != #num_fields { return false; } @@ -94,7 +124,7 @@ pub fn composite_body(name: &str, trait_: &str, fields: &[Field]) -> TokenStream _ => false, } }) - } + }, _ => false, } } diff --git a/postgres-derive/src/fromsql.rs b/postgres-derive/src/fromsql.rs index d3ac47f4f..0680d0b13 100644 --- a/postgres-derive/src/fromsql.rs +++ b/postgres-derive/src/fromsql.rs @@ -101,7 +101,7 @@ pub fn expand_derive_fromsql(input: DeriveInput) -> Result { .map(|field| Field::parse(field, overrides.rename_all)) .collect::, _>>()?; ( - accepts::composite_body(&name, "FromSql", &fields), + accepts::composite_body_from_sql(&name, &fields), composite_body(&input.ident, &fields), ) } @@ -171,6 +171,9 @@ fn domain_accepts_body(name: &str, field: &syn::Field) -> TokenStream { if <#ty as postgres_types::FromSql>::accepts(type_) { return true; } + if *type_ == postgres_types::Type::UNKNOWN { + return true; + } #normal_body } @@ -191,45 +194,92 @@ fn composite_body(ident: &Ident, fields: &[Field]) -> TokenStream { let field_names = &fields.iter().map(|f| &f.name).collect::>(); let field_idents = &fields.iter().map(|f| &f.ident).collect::>(); + let field_types = &fields.iter().map(|f| &f.type_).collect::>(); + let field_indices = (0..fields.len()).collect::>(); + let field_count = fields.len(); + quote! { - let fields = match *_type.kind() { - postgres_types::Kind::Composite(ref fields) => fields, - _ => unreachable!(), - }; + match *_type.kind() { + postgres_types::Kind::Composite(ref fields) => { + let mut buf = buf; + let num_fields = postgres_types::private::read_be_i32(&mut buf)?; + if num_fields as usize != fields.len() { + return std::result::Result::Err(std::convert::Into::into(format!( + "invalid field count: {} vs {}", + num_fields, + fields.len() + ))); + } - let mut buf = buf; - let num_fields = postgres_types::private::read_be_i32(&mut buf)?; - if num_fields as usize != fields.len() { - return std::result::Result::Err( - std::convert::Into::into(format!("invalid field count: {} vs {}", num_fields, fields.len()))); - } + #( + let mut #temp_vars = std::option::Option::None; + )* - #( - let mut #temp_vars = std::option::Option::None; - )* + for field in fields { + let oid = postgres_types::private::read_be_i32(&mut buf)? as u32; + if oid != field.type_().oid() { + return std::result::Result::Err(std::convert::Into::into("unexpected OID")); + } - for field in fields { - let oid = postgres_types::private::read_be_i32(&mut buf)? as u32; - if oid != field.type_().oid() { - return std::result::Result::Err(std::convert::Into::into("unexpected OID")); - } + match field.name() { + #( + #field_names => { + #temp_vars = std::option::Option::Some( + postgres_types::private::read_value(field.type_(), &mut buf)?, + ); + } + )* + _ => unreachable!(), + } + } + + std::result::Result::Ok(#ident { + #( + #field_idents: #temp_vars.unwrap(), + )* + }) + }, + postgres_types::Kind::Pseudo if *_type == postgres_types::Type::RECORD => { + let mut buf = buf; + let num_fields = postgres_types::private::read_be_i32(&mut buf)?; + if num_fields as usize != #field_count { + return std::result::Result::Err( + std::convert::Into::into(format!("invalid field count: {} vs {}", num_fields, #field_count))); + } - match field.name() { #( - #field_names => { - #temp_vars = std::option::Option::Some( - postgres_types::private::read_value(field.type_(), &mut buf)?); - } + let mut #temp_vars = std::option::Option::None; )* - _ => unreachable!(), - } - } - std::result::Result::Ok(#ident { - #( - #field_idents: #temp_vars.unwrap(), - )* - }) + for i in 0..num_fields { + let oid = postgres_types::private::read_be_i32(&mut buf)? as u32; + // FIXME: I see no other way to make this work with non-builtin types + // besides expanding the FromSql trait to also require passing the client + let ty = postgres_types::Type::from_oid(oid).unwrap_or(postgres_types::Type::UNKNOWN); + + match i as usize { + #( + #field_indices => { + if !<#field_types as postgres_types::FromSql>::accepts(&ty) { + return std::result::Result::Err(std::boxed::Box::new( + postgres_types::WrongType::new::<#field_types>(ty.clone()))); + } + #temp_vars = std::option::Option::Some( + postgres_types::private::read_value(&ty, &mut buf)?); + } + )* + _ => unreachable!(), + } + } + + std::result::Result::Ok(#ident { + #( + #field_idents: #temp_vars.unwrap(), + )* + }) + }, + _ => unreachable!(), + } } } diff --git a/postgres-derive/src/tosql.rs b/postgres-derive/src/tosql.rs index 81d4834bf..c34604e88 100644 --- a/postgres-derive/src/tosql.rs +++ b/postgres-derive/src/tosql.rs @@ -95,7 +95,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { .map(|field| Field::parse(field, overrides.rename_all)) .collect::, _>>()?; ( - accepts::composite_body(&name, "ToSql", &fields), + accepts::composite_body_to_sql(&name, &fields), composite_body(&fields), ) } @@ -112,7 +112,7 @@ pub fn expand_derive_tosql(input: DeriveInput) -> Result { let generics = append_generic_bound(input.generics.to_owned(), &new_tosql_bound()); let (impl_generics, ty_generics, where_clause) = generics.split_for_impl(); let out = quote! { - impl#impl_generics postgres_types::ToSql for #ident#ty_generics #where_clause { + impl #impl_generics postgres_types::ToSql for #ident #ty_generics #where_clause { fn to_sql(&self, _type: &postgres_types::Type, buf: &mut postgres_types::private::BytesMut) diff --git a/postgres-types/src/type_gen.rs b/postgres-types/src/type_gen.rs index a1bc3f85c..bc513e0b1 100644 --- a/postgres-types/src/type_gen.rs +++ b/postgres-types/src/type_gen.rs @@ -700,7 +700,7 @@ impl Inner { Inner::LanguageHandler => &Kind::Pseudo, Inner::Internal => &Kind::Pseudo, Inner::Anyelement => &Kind::Pseudo, - Inner::RecordArray => &Kind::Pseudo, + Inner::RecordArray => &Kind::Array(Type(Inner::Record)), Inner::Anynonarray => &Kind::Pseudo, Inner::TxidSnapshotArray => &Kind::Array(Type(Inner::TxidSnapshot)), Inner::Uuid => &Kind::Simple,