5353log = logging .getLogger (__name__ )
5454escaper = TokenEscaper ()
5555
56+
57+ class PartialModel :
58+ """A partial model instance that only contains certain fields.
59+
60+ Accessing fields that weren't loaded will raise AttributeError.
61+ This is used for .only() queries to provide partial model instances.
62+ """
63+
64+ def __init__ (self , model_class , data : dict , loaded_fields : set ):
65+ self .__dict__ ["_model_class" ] = model_class
66+ self .__dict__ ["_loaded_fields" ] = loaded_fields
67+ self .__dict__ ["_data" ] = data
68+
69+ # Set the loaded field values
70+ for field_name , value in data .items ():
71+ self .__dict__ [field_name ] = value
72+
73+ def __getattribute__ (self , name ):
74+ # Allow access to internal attributes and methods
75+ if name .startswith ("_" ) or name in (
76+ "model_fields" ,
77+ "model_config" ,
78+ "__class__" ,
79+ "__dict__" ,
80+ ):
81+ return super ().__getattribute__ (name )
82+
83+ # Get model class to check if this is a model field
84+ model_class = super ().__getattribute__ ("_model_class" )
85+ loaded_fields = super ().__getattribute__ ("_loaded_fields" )
86+
87+ # If it's a model field that wasn't loaded, raise an error
88+ if hasattr (model_class , "model_fields" ) and name in model_class .model_fields :
89+ if name not in loaded_fields :
90+ raise AttributeError (
91+ f"Field '{ name } ' is missing from this query. "
92+ f"Use .only('{ name } ') or .only({ ', ' .join (repr (field ) for field in sorted (loaded_fields .union ({name })))} ) to include it."
93+ )
94+
95+ return super ().__getattribute__ (name )
96+
97+ def __setattr__ (self , name , value ):
98+ # Allow setting internal attributes
99+ if name .startswith ("_" ):
100+ self .__dict__ [name ] = value
101+ else :
102+ # For regular fields, check if they were loaded
103+ if name not in self ._loaded_fields :
104+ raise AttributeError (
105+ f"Cannot set field '{ name } ' - it is missing from this query."
106+ )
107+ self .__dict__ [name ] = value
108+
109+ def __repr__ (self ):
110+ loaded_data = {k : v for k , v in self ._data .items () if k in self ._loaded_fields }
111+ return f"Partial{ self ._model_class .__name__ } ({ loaded_data } )"
112+
113+
56114# For basic exact-match field types like an indexed string, we create a TAG
57115# field in the RediSearch index. TAG is designed for multi-value fields
58116# separated by a "separator" character. We're using the field for single values
@@ -503,7 +561,7 @@ def query(self):
503561 """
504562 if self ._query :
505563 return self ._query
506- self ._query = self .resolve_redisearch_query (self .expression )
564+ self ._query = self ._resolve_redisearch_query (self .expression )
507565 if self .knn :
508566 self ._query = (
509567 self ._query
@@ -541,15 +599,98 @@ def to_string(s):
541599 if res [i + offset ] is None :
542600 continue
543601 # When using RETURN, we get flat key-value pairs
544- fields : Dict [str , str ] = dict (
602+ raw_fields : Dict [str , str ] = dict (
545603 zip (
546604 map (to_string , res [i + offset ][::2 ]),
547605 map (to_string , res [i + offset ][1 ::2 ]),
548606 )
549607 )
550- docs .append (fields )
608+ # Convert raw Redis strings to properly typed values
609+ converted_fields = self ._convert_projected_fields (raw_fields )
610+ docs .append (converted_fields )
551611 return docs
552612
613+ def _convert_projected_fields (self , raw_data : Dict [str , str ]) -> Dict [str , Any ]:
614+ """Convert raw Redis string values to properly typed values using model field info."""
615+
616+ # Fast path: Try creating a single model instance with all projected fields
617+ # This is more efficient and handles field interdependencies
618+ try :
619+ # Use model_validate instead of model_construct to ensure type conversion
620+ temp_model = self .model .model_validate (raw_data , strict = False )
621+
622+ # Use model_dump() to efficiently extract all converted values
623+ all_converted = temp_model .model_dump ()
624+
625+ # Filter to only the fields we actually projected
626+ converted_data = {
627+ k : all_converted [k ] for k in raw_data .keys () if k in all_converted
628+ }
629+
630+ return converted_data
631+
632+ except Exception : # nosec B110
633+ # If validation fails (due to missing required fields), fall back to individual conversion
634+ # This is expected for partial field sets
635+ pass
636+
637+ # Fallback path: Convert each field individually using type information
638+ converted_data = {}
639+ for field_name , raw_value in raw_data .items ():
640+ if field_name not in self .model .model_fields :
641+ # Unknown field, keep as string
642+ converted_data [field_name ] = raw_value
643+ continue
644+
645+ try :
646+ field_info = self .model .model_fields [field_name ]
647+
648+ # Get the field type annotation
649+ if hasattr (field_info , "annotation" ):
650+ field_type = field_info .annotation
651+ else :
652+ field_type = getattr (field_info , "type_" , str )
653+
654+ # Handle common type conversions directly for efficiency
655+ if field_type == int :
656+ converted_data [field_name ] = int (raw_value )
657+ elif field_type == float :
658+ converted_data [field_name ] = float (raw_value )
659+ elif field_type == bool :
660+ # Redis may store bool as "True"/"False" or "1"/"0"
661+ converted_data [field_name ] = raw_value .lower () in (
662+ "true" ,
663+ "1" ,
664+ "yes" ,
665+ )
666+ elif field_type == str :
667+ converted_data [field_name ] = raw_value
668+ else :
669+ # For complex types, keep as string (could be enhanced later)
670+ converted_data [field_name ] = raw_value
671+
672+ except (ValueError , TypeError ):
673+ # If conversion fails, keep the raw string value
674+ converted_data [field_name ] = raw_value
675+
676+ return converted_data
677+
678+ def _parse_projected_models (self , res : Any ) -> List [PartialModel ]:
679+ """Parse results when using RETURN clause to create partial model instances."""
680+ projected_dicts = self ._parse_projected_results (res )
681+
682+ # Create partial model instances that will raise errors for missing fields
683+ partial_models = []
684+ for data in projected_dicts :
685+ partial_model = PartialModel (
686+ model_class = self .model ,
687+ data = data ,
688+ loaded_fields = set (self .projected_fields ),
689+ )
690+ partial_models .append (partial_model )
691+
692+ return partial_models
693+
553694 @property
554695 def query_params (self ):
555696 params : List [Union [str , bytes ]] = []
@@ -669,6 +810,7 @@ def resolve_value(
669810 op : Operators ,
670811 value : Any ,
671812 parents : List [Tuple [str , "RedisModel" ]],
813+ model_class : Optional [Type ["RedisModel" ]] = None ,
672814 ) -> str :
673815 # The 'field_name' should already include the correct prefix
674816 result = ""
@@ -724,8 +866,18 @@ def resolve_value(
724866 )
725867 return ""
726868 if isinstance (value , bool ):
869+ # For HashModel, convert boolean to "1"/"0" to match storage format
870+ # For JsonModel, keep as boolean since JSON supports native booleans
871+ if model_class :
872+ # Check if this is a HashModel by checking the class hierarchy
873+ is_hash_model = any (
874+ base .__name__ == "HashModel" for base in model_class .__mro__
875+ )
876+ bool_value = ("1" if value else "0" ) if is_hash_model else value
877+ else :
878+ bool_value = value
727879 result = "@{field_name}:{{{value}}}" .format (
728- field_name = field_name , value = value
880+ field_name = field_name , value = bool_value
729881 )
730882 elif isinstance (value , int ):
731883 # This if will hit only if the field is a primary key of type int
@@ -803,8 +955,7 @@ def resolve_redisearch_sort_fields(self):
803955 if self .sort_fields :
804956 return ["SORTBY" , * fields ]
805957
806- @classmethod
807- def resolve_redisearch_query (cls , expression : ExpressionOrNegated ) -> str :
958+ def _resolve_redisearch_query (self , expression : ExpressionOrNegated ) -> str :
808959 """
809960 Resolve an arbitrarily deep expression into a single RediSearch query string.
810961
@@ -848,9 +999,11 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
848999 if isinstance (expression .left , Expression ) or isinstance (
8491000 expression .left , NegatedExpression
8501001 ):
851- result += f"({ cls . resolve_redisearch_query (expression .left )} )"
1002+ result += f"({ self . _resolve_redisearch_query (expression .left )} )"
8521003 elif isinstance (expression .left , FieldInfo ):
853- field_type = cls .resolve_field_type (expression .left , expression .op )
1004+ field_type = self .__class__ .resolve_field_type (
1005+ expression .left , expression .op
1006+ )
8541007 field_name = expression .left .name
8551008 field_info = expression .left
8561009 if not field_info or not getattr (field_info , "index" , None ):
@@ -881,7 +1034,7 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
8811034 result += "-"
8821035 right = right .expression
8831036
884- result += f"({ cls . resolve_redisearch_query (right )} )"
1037+ result += f"({ self . _resolve_redisearch_query (right )} )"
8851038 else :
8861039 if not field_name :
8871040 raise QuerySyntaxError ("Could not resolve field name. See docs: TODO" )
@@ -890,13 +1043,14 @@ def resolve_redisearch_query(cls, expression: ExpressionOrNegated) -> str:
8901043 elif not field_info :
8911044 raise QuerySyntaxError ("Could not resolve field info. See docs: TODO" )
8921045 else :
893- result += cls .resolve_value (
1046+ result += self . __class__ .resolve_value (
8941047 field_name ,
8951048 field_type ,
8961049 field_info ,
8971050 expression .op ,
8981051 right ,
8991052 expression .parents ,
1053+ self .model ,
9001054 )
9011055
9021056 if encompassing_expression_is_negated :
@@ -951,16 +1105,19 @@ async def execute(
9511105 return raw_result
9521106 count = raw_result [0 ]
9531107
954- # If we're using field projection or explicitly requesting dict output,
955- # return dictionaries instead of model instances
956- if self .projected_fields or self .return_as_dict :
957- if self .projected_fields :
958- results = self ._parse_projected_results (raw_result )
959- else :
960- # Return all fields as dicts - need to convert from model instances
961- model_results = self .model .from_redis (raw_result , self .knn )
962- results = [model .model_dump () for model in model_results ]
1108+ # Handle different result processing based on what was requested
1109+ if self .projected_fields and self .return_as_dict :
1110+ # .values('field1', 'field2') - specific fields as dicts
1111+ results = self ._parse_projected_results (raw_result )
1112+ elif self .projected_fields and not self .return_as_dict :
1113+ # .only('field1', 'field2') - partial model instances
1114+ results = self ._parse_projected_models (raw_result )
1115+ elif self .return_as_dict and not self .projected_fields :
1116+ # .values() - all fields as dicts
1117+ model_results = self .model .from_redis (raw_result , self .knn )
1118+ results = [model .model_dump () for model in model_results ]
9631119 else :
1120+ # Normal query - full model instances
9641121 results = self .model .from_redis (raw_result , self .knn )
9651122 self ._model_cache += results
9661123
@@ -1019,10 +1176,10 @@ def sort_by(self, *fields: str):
10191176 def values (self , * fields : str ):
10201177 """
10211178 Return query results as dictionaries instead of model instances.
1022-
1179+
10231180 If no fields are specified, returns all fields.
10241181 If fields are specified, returns only those fields.
1025-
1182+
10261183 Usage:
10271184 await Model.find().values() # All fields as dicts
10281185 await Model.find().values('name', 'email') # Only specified fields
@@ -1034,6 +1191,20 @@ def values(self, *fields: str):
10341191 # Return specific fields as dicts
10351192 return self .copy (return_as_dict = True , projected_fields = list (fields ))
10361193
1194+ def only (self , * fields : str ):
1195+ """
1196+ Return query results as model instances with only the specified fields loaded.
1197+
1198+ Accessing fields that weren't loaded will raise an AttributeError.
1199+ Uses Redis RETURN clause for efficient field projection.
1200+
1201+ Usage:
1202+ await Model.find().only('name', 'email').all() # Partial model instances
1203+ """
1204+ if not fields :
1205+ raise ValueError ("only() requires at least one field name" )
1206+ return self .copy (projected_fields = list (fields ))
1207+
10371208 async def update (self , use_transaction = True , ** field_values ):
10381209 """
10391210 Update models that match this query to the given field-value pairs.
@@ -1766,6 +1937,13 @@ async def save(
17661937
17671938 # filter out values which are `None` because they are not valid in a HSET
17681939 document = {k : v for k , v in document .items () if v is not None }
1940+
1941+ # Convert boolean values to "1"/"0" for storage efficiency (Redis HSET doesn't support booleans)
1942+ document = {
1943+ k : ("1" if v else "0" ) if isinstance (v , bool ) else v
1944+ for k , v in document .items ()
1945+ }
1946+
17691947 # TODO: Wrap any Redis response errors in a custom exception?
17701948 await db .hset (self .key (), mapping = document )
17711949 return self
0 commit comments