diff --git a/Cargo.toml b/Cargo.toml index d5f77c6..6d0680c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -9,11 +9,13 @@ repository = "https://github.com/sfackler/rust-postgres-range" [features] with-chrono-0_4 = ["chrono-04", "postgres-types/with-chrono-0_4"] +with-rust_decimal-1 = ["rust_decimal-1"] [dependencies] postgres-protocol = "0.6" postgres-types = "0.2" chrono-04 = { version = "0.4", package = "chrono", optional = true, default-features = false } +rust_decimal-1 = { version = "1.32.0", package = "rust_decimal", optional = true, default-features = false, features = ["db-tokio-postgres"] } [dev-dependencies] postgres = "0.19" diff --git a/src/impls.rs b/src/impls.rs index 19615e1..c8b2db4 100644 --- a/src/impls.rs +++ b/src/impls.rs @@ -1,7 +1,7 @@ -use std::error::Error; -use postgres_types::{FromSql, IsNull, Kind, ToSql, Type}; -use postgres_types::private::BytesMut; use postgres_protocol::{self as protocol, types}; +use postgres_types::private::BytesMut; +use postgres_types::{FromSql, IsNull, Kind, ToSql, Type}; +use std::error::Error; use crate::{BoundSided, BoundType, Normalizable, Range, RangeBound}; @@ -33,7 +33,10 @@ where } } -fn bound_from_sql<'a, T, S>(bound: types::RangeBound>, ty: &Type) -> Result>, Box> +fn bound_from_sql<'a, T, S>( + bound: types::RangeBound>, + ty: &Type, +) -> Result>, Box> where T: PartialOrd + Normalizable + FromSql<'a>, S: BoundSided, @@ -61,7 +64,11 @@ impl ToSql for Range where T: PartialOrd + Normalizable + ToSql, { - fn to_sql(&self, ty: &Type, buf: &mut BytesMut) -> Result> { + fn to_sql( + &self, + ty: &Type, + buf: &mut BytesMut, + ) -> Result> { let element_type = match *ty.kind() { Kind::Range(ref ty) => ty, _ => panic!("unexpected type {:?}", ty), @@ -90,7 +97,11 @@ where to_sql_checked!(); } -fn bound_to_sql(bound: Option<&RangeBound>, ty: &Type, buf: &mut BytesMut) -> Result, Box> +fn bound_to_sql( + bound: Option<&RangeBound>, + ty: &Type, + buf: &mut BytesMut, +) -> Result, Box> where S: BoundSided, T: ToSql, @@ -115,10 +126,12 @@ where mod test { use std::fmt; - use postgres::{Client, NoTls}; - use postgres::types::{FromSql, ToSql}; #[cfg(feature = "with-chrono-0_4")] - use chrono_04::{TimeZone, Utc, Duration}; + use chrono_04::{Duration, TimeZone, Utc}; + use postgres::types::{FromSql, ToSql}; + use postgres::{Client, NoTls}; + #[cfg(feature = "with-rust_decimal-1")] + use rust_decimal_1::Decimal; macro_rules! test_range { ($name:expr, $t:ty, $low:expr, $low_str:expr, $high:expr, $high_str:expr) => ({ @@ -141,21 +154,33 @@ mod test { }) } - fn test_type(sql_type: &str, checks: &[(T, S)]) - where for<'a> - T: Sync + PartialEq + FromSql<'a> + ToSql, - S: fmt::Display + where + for<'a> T: Sync + PartialEq + FromSql<'a> + ToSql, + S: fmt::Display, { let mut conn = Client::connect("postgres://postgres@localhost", NoTls).unwrap(); for &(ref val, ref repr) in checks { - let stmt = conn.prepare(&*format!("SELECT {}::{}", *repr, sql_type)) + let stmt = conn + .prepare(&*format!("SELECT {}::{}", *repr, sql_type)) .unwrap(); - let result = conn.query(&stmt, &[]).unwrap().iter().next().unwrap().get(0); + let result = conn + .query(&stmt, &[]) + .unwrap() + .iter() + .next() + .unwrap() + .get(0); assert!(val == &result); let stmt = conn.prepare(&*format!("SELECT $1::{}", sql_type)).unwrap(); - let result = conn.query(&stmt, &[val]).unwrap().iter().next().unwrap().get(0); + let result = conn + .query(&stmt, &[val]) + .unwrap() + .iter() + .next() + .unwrap() + .get(0); assert!(val == &result); } } @@ -170,19 +195,41 @@ mod test { test_range!("INT8RANGE", i64, 100i64, "100", 200i64, "200") } + #[test] + #[cfg(feature = "with-rust_decimal-1")] + fn test_numrange_params() { + let low = Decimal::new(202, 2); + let high = Decimal::new(202, 1); + test_range!("NUMRANGE", Decimal, low, "2.02", high, "20.2"); + } + #[test] #[cfg(feature = "with-chrono-0_4")] fn test_tsrange_params() { - let low = Utc.timestamp(0, 0); + let low = Utc.timestamp_opt(0, 0).unwrap(); let high = low + Duration::days(10); - test_range!("TSRANGE", NaiveDateTime, low.naive_utc(), "1970-01-01", high.naive_utc(), "1970-01-11"); + test_range!( + "TSRANGE", + NaiveDateTime, + low.naive_utc(), + "1970-01-01", + high.naive_utc(), + "1970-01-11" + ); } #[test] #[cfg(feature = "with-chrono-0_4")] fn test_tstzrange_params() { - let low = Utc.timestamp(0, 0); + let low = Utc.timestamp_opt(0, 0).unwrap(); let high = low + Duration::days(10); - test_range!("TSTZRANGE", DateTime, low, "1970-01-01", high, "1970-01-11"); + test_range!( + "TSTZRANGE", + DateTime, + low, + "1970-01-01", + high, + "1970-01-11" + ); } } diff --git a/src/lib.rs b/src/lib.rs index a1d68a5..02893f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -8,6 +8,9 @@ extern crate postgres_types; #[cfg(feature = "with-chrono-0_4")] mod chrono_04; +#[cfg(feature = "with-rust_decimal-1")] +mod rust_decimal_1; + use std::cmp::Ordering; use std::fmt; use std::i32; @@ -54,56 +57,60 @@ use InnerRange::{Empty, Normal}; /// } #[macro_export] macro_rules! range { - (empty) => ($crate::Range::empty()); - ('(',; ')') => ($crate::Range::new(None, None)); - ('(', $h:expr; ')') => ( - $crate::Range::new(None, - Some($crate::RangeBound::new($h, - $crate::BoundType::Exclusive))) - ); - ('(', $h:expr; ']') => ( - $crate::Range::new(None, - Some($crate::RangeBound::new($h, - $crate::BoundType::Inclusive))) - ); - ('(' $l:expr,; ')') => ( + (empty) => { + $crate::Range::empty() + }; + ('(',; ')') => { + $crate::Range::new(None, None) + }; + ('(', $h:expr; ')') => { + $crate::Range::new( + None, + Some($crate::RangeBound::new($h, $crate::BoundType::Exclusive)), + ) + }; + ('(', $h:expr; ']') => { + $crate::Range::new( + None, + Some($crate::RangeBound::new($h, $crate::BoundType::Inclusive)), + ) + }; + ('(' $l:expr,; ')') => { $crate::Range::new( - Some($crate::RangeBound::new($l, - $crate::BoundType::Exclusive)), None) - ); - ('[' $l:expr,; ')') => ( + Some($crate::RangeBound::new($l, $crate::BoundType::Exclusive)), + None, + ) + }; + ('[' $l:expr,; ')') => { $crate::Range::new( - Some($crate::RangeBound::new($l, - $crate::BoundType::Inclusive)), None) - ); - ('(' $l:expr, $h:expr; ')') => ( + Some($crate::RangeBound::new($l, $crate::BoundType::Inclusive)), + None, + ) + }; + ('(' $l:expr, $h:expr; ')') => { $crate::Range::new( - Some($crate::RangeBound::new($l, - $crate::BoundType::Exclusive)), - Some($crate::RangeBound::new($h, - $crate::BoundType::Exclusive))) - ); - ('(' $l:expr, $h:expr; ']') => ( + Some($crate::RangeBound::new($l, $crate::BoundType::Exclusive)), + Some($crate::RangeBound::new($h, $crate::BoundType::Exclusive)), + ) + }; + ('(' $l:expr, $h:expr; ']') => { $crate::Range::new( - Some($crate::RangeBound::new($l, - $crate::BoundType::Exclusive)), - Some($crate::RangeBound::new($h, - $crate::BoundType::Inclusive))) - ); - ('[' $l:expr, $h:expr; ')') => ( + Some($crate::RangeBound::new($l, $crate::BoundType::Exclusive)), + Some($crate::RangeBound::new($h, $crate::BoundType::Inclusive)), + ) + }; + ('[' $l:expr, $h:expr; ')') => { $crate::Range::new( - Some($crate::RangeBound::new($l, - $crate::BoundType::Inclusive)), - Some($crate::RangeBound::new($h, - $crate::BoundType::Exclusive))) - ); - ('[' $l:expr, $h:expr; ']') => ( + Some($crate::RangeBound::new($l, $crate::BoundType::Inclusive)), + Some($crate::RangeBound::new($h, $crate::BoundType::Exclusive)), + ) + }; + ('[' $l:expr, $h:expr; ']') => { $crate::Range::new( - Some($crate::RangeBound::new($l, - $crate::BoundType::Inclusive)), - Some($crate::RangeBound::new($h, - $crate::BoundType::Inclusive))) - ) + Some($crate::RangeBound::new($l, $crate::BoundType::Inclusive)), + Some($crate::RangeBound::new($h, $crate::BoundType::Inclusive)), + ) + }; } mod impls; @@ -124,9 +131,12 @@ pub trait Normalizable: Sized { } macro_rules! bounded_normalizable { - ($t:ident) => ( + ($t:ident) => { impl Normalizable for $t { - fn normalize(bound: RangeBound) -> RangeBound where S: BoundSided { + fn normalize(bound: RangeBound) -> RangeBound + where + S: BoundSided, + { match (::side(), bound.type_) { (Upper, Inclusive) => { assert!(bound.value != $t::MAX); @@ -136,11 +146,11 @@ macro_rules! bounded_normalizable { assert!(bound.value != $t::MAX); RangeBound::new(bound.value + 1, Inclusive) } - _ => bound + _ => bound, } } } - ) + }; } bounded_normalizable!(i32); @@ -289,8 +299,10 @@ where other.type_, self.value.partial_cmp(&other.value), ) { - (Upper, Exclusive, Inclusive, Some(Ordering::Equal)) | (Lower, Inclusive, Exclusive, Some(Ordering::Equal)) => Some(Ordering::Less), - (Upper, Inclusive, Exclusive, Some(Ordering::Equal)) | (Lower, Exclusive, Inclusive, Some(Ordering::Equal)) => Some(Ordering::Greater), + (Upper, Exclusive, Inclusive, Some(Ordering::Equal)) + | (Lower, Inclusive, Exclusive, Some(Ordering::Equal)) => Some(Ordering::Less), + (Upper, Inclusive, Exclusive, Some(Ordering::Equal)) + | (Lower, Exclusive, Inclusive, Some(Ordering::Equal)) => Some(Ordering::Greater), (_, _, _, cmp) => cmp, } } @@ -308,8 +320,10 @@ where other.type_, self.value.cmp(&other.value), ) { - (Upper, Exclusive, Inclusive, Ordering::Equal) | (Lower, Inclusive, Exclusive, Ordering::Equal) => Ordering::Less, - (Upper, Inclusive, Exclusive, Ordering::Equal) | (Lower, Exclusive, Inclusive, Ordering::Equal) => Ordering::Greater, + (Upper, Exclusive, Inclusive, Ordering::Equal) + | (Lower, Inclusive, Exclusive, Ordering::Equal) => Ordering::Less, + (Upper, Inclusive, Exclusive, Ordering::Equal) + | (Lower, Exclusive, Inclusive, Ordering::Equal) => Ordering::Greater, (_, _, _, ord) => ord, } } @@ -417,7 +431,10 @@ where /// Creates a new range. /// /// If a bound is `None`, the range is unbounded in that direction. - pub fn new(lower: Option>, upper: Option>) -> Range { + pub fn new( + lower: Option>, + upper: Option>, + ) -> Range { let lower = lower.map(Normalizable::normalize); let upper = upper.map(Normalizable::normalize); @@ -470,11 +487,8 @@ where match self.inner { Empty => false, Normal(ref lower, ref upper) => { - lower.as_ref().map_or(true, |b| { - b.in_bounds(value) - }) && upper.as_ref().map_or(true, |b| { - b.in_bounds(value) - }) + lower.as_ref().map_or(true, |b| b.in_bounds(value)) + && upper.as_ref().map_or(true, |b| b.in_bounds(value)) } } } @@ -489,7 +503,8 @@ where return false; } - OptBound(self.lower()) <= OptBound(other.lower()) && OptBound(self.upper()) >= OptBound(other.upper()) + OptBound(self.lower()) <= OptBound(other.lower()) + && OptBound(self.upper()) >= OptBound(other.upper()) } } @@ -530,8 +545,10 @@ where return Some(self.clone()); } - let (OptBound(l_lower), OptBound(u_lower)) = order(OptBound(self.lower()), OptBound(other.lower())); - let (OptBound(l_upper), OptBound(u_upper)) = order(OptBound(self.upper()), OptBound(other.upper())); + let (OptBound(l_lower), OptBound(u_lower)) = + order(OptBound(self.lower()), OptBound(other.lower())); + let (OptBound(l_upper), OptBound(u_upper)) = + order(OptBound(self.upper()), OptBound(other.upper())); let discontiguous = match (u_lower, l_upper) { ( @@ -546,17 +563,16 @@ where .. }), ) => l >= u, - (Some(&RangeBound { value: ref l, .. }), Some(&RangeBound { value: ref u, .. })) => l > u, + (Some(&RangeBound { value: ref l, .. }), Some(&RangeBound { value: ref u, .. })) => { + l > u + } _ => false, }; if discontiguous { None } else { - Some(Range::new( - l_lower.cloned(), - u_upper.cloned(), - )) + Some(Range::new(l_lower.cloned(), u_upper.cloned())) } } } @@ -565,8 +581,8 @@ where mod test { use std::i32; - use super::{BoundType, LowerBound, Normalizable, Range, RangeBound, UpperBound}; use super::BoundType::{Exclusive, Inclusive}; + use super::{BoundType, LowerBound, Normalizable, Range, RangeBound, UpperBound}; #[test] fn test_range_bound_lower_lt() { diff --git a/src/rust_decimal_1.rs b/src/rust_decimal_1.rs new file mode 100644 index 0000000..558aaf5 --- /dev/null +++ b/src/rust_decimal_1.rs @@ -0,0 +1,12 @@ +use rust_decimal_1::Decimal; + +use crate::{BoundSided, Normalizable, RangeBound}; + +impl Normalizable for Decimal { + fn normalize(bound: RangeBound) -> RangeBound + where + S: BoundSided, + { + bound + } +}