@@ -21,7 +21,10 @@ use super::datetime::{
21
21
float_as_duration, float_as_time, int_as_datetime, int_as_duration, int_as_time, EitherDate , EitherDateTime ,
22
22
EitherTime ,
23
23
} ;
24
- use super :: shared:: { decimal_as_int, float_as_int, int_as_bool, map_json_err, str_as_bool, str_as_float, str_as_int} ;
24
+ use super :: shared:: {
25
+ decimal_as_int, float_as_int, get_enum_meta_object, int_as_bool, map_json_err, str_as_bool, str_as_float,
26
+ str_as_int,
27
+ } ;
25
28
use super :: {
26
29
py_string_str, BorrowInput , EitherBytes , EitherFloat , EitherInt , EitherString , EitherTimedelta , GenericArguments ,
27
30
GenericIterable , GenericIterator , GenericMapping , Input , JsonInput , PyArgs ,
@@ -256,6 +259,8 @@ impl<'a> Input<'a> for PyAny {
256
259
|| self . is_instance ( decimal_type. as_ref ( py) ) . unwrap_or_default ( )
257
260
} {
258
261
Ok ( self . str ( ) ?. into ( ) )
262
+ } else if let Some ( enum_val) = maybe_as_enum ( self ) {
263
+ Ok ( enum_val. str ( ) ?. into ( ) )
259
264
} else {
260
265
Err ( ValError :: new ( ErrorTypeDefaults :: StringType , self ) )
261
266
}
@@ -340,6 +345,8 @@ impl<'a> Input<'a> for PyAny {
340
345
decimal_as_int ( self . py ( ) , self , decimal)
341
346
} else if let Ok ( float) = self . extract :: < f64 > ( ) {
342
347
float_as_int ( self , float)
348
+ } else if let Some ( enum_val) = maybe_as_enum ( self ) {
349
+ Ok ( EitherInt :: Py ( enum_val) )
343
350
} else {
344
351
Err ( ValError :: new ( ErrorTypeDefaults :: IntType , self ) )
345
352
}
@@ -759,6 +766,18 @@ fn maybe_as_string(v: &PyAny, unicode_error: ErrorType) -> ValResult<Option<Cow<
759
766
}
760
767
}
761
768
769
+ /// Utility for extracting an enum value, if possible.
770
+ fn maybe_as_enum ( v : & PyAny ) -> Option < & PyAny > {
771
+ let py = v. py ( ) ;
772
+ let enum_meta_object = get_enum_meta_object ( py) ;
773
+ let meta_type = v. get_type ( ) . get_type ( ) ;
774
+ if meta_type. is ( & enum_meta_object) {
775
+ v. getattr ( intern ! ( py, "value" ) ) . ok ( )
776
+ } else {
777
+ None
778
+ }
779
+ }
780
+
762
781
#[ cfg( PyPy ) ]
763
782
static DICT_KEYS_TYPE : pyo3:: once_cell:: GILOnceCell < Py < PyType > > = pyo3:: once_cell:: GILOnceCell :: new ( ) ;
764
783
0 commit comments