11using System ;
2+ using System . Data ;
23using System . Dynamic ;
34using System . Linq ;
45using System . Linq . Expressions ;
@@ -242,10 +243,13 @@ constant.Value is CallSite site &&
242243 protected HqlTreeNode VisitNhAverage ( NhAverageExpression expression )
243244 {
244245 var hqlExpression = VisitExpression ( expression . Expression ) . AsExpression ( ) ;
245- if ( expression . Type != expression . Expression . Type )
246- hqlExpression = _hqlTreeBuilder . Cast ( hqlExpression , expression . Type ) ;
246+ hqlExpression = IsCastRequired ( expression . Expression , expression . Type )
247+ ? ( HqlExpression ) _hqlTreeBuilder . Cast ( hqlExpression , expression . Type )
248+ : _hqlTreeBuilder . TransparentCast ( hqlExpression , expression . Type ) ;
247249
248- return _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type ) ;
250+ return IsCastRequired ( expression . Type , "avg" )
251+ ? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type )
252+ : _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Average ( hqlExpression ) , expression . Type ) ;
249253 }
250254
251255 protected HqlTreeNode VisitNhCount ( NhCountExpression expression )
@@ -265,17 +269,9 @@ protected HqlTreeNode VisitNhMax(NhMaxExpression expression)
265269
266270 protected HqlTreeNode VisitNhSum ( NhSumExpression expression )
267271 {
268- var type = expression . Type . UnwrapIfNullable ( ) ;
269- var nhType = TypeFactory . GetDefaultTypeFor ( type ) ;
270- if ( nhType != null && _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( "sum" )
271- ? . ReturnType ( nhType , _parameters . SessionFactory ) ? . ReturnedClass == type )
272- {
273- return _hqlTreeBuilder . TransparentCast (
274- _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) ,
275- expression . Type ) ;
276- }
277-
278- return _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
272+ return IsCastRequired ( expression . Type , "sum" )
273+ ? ( HqlTreeNode ) _hqlTreeBuilder . Cast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type )
274+ : _hqlTreeBuilder . TransparentCast ( _hqlTreeBuilder . Sum ( VisitExpression ( expression . Expression ) . AsExpression ( ) ) , expression . Type ) ;
279275 }
280276
281277 protected HqlTreeNode VisitNhDistinct ( NhDistinctExpression expression )
@@ -489,15 +485,9 @@ protected HqlTreeNode VisitUnaryExpression(UnaryExpression expression)
489485 case ExpressionType . Convert :
490486 case ExpressionType . ConvertChecked :
491487 case ExpressionType . TypeAs :
492- var operandType = expression . Operand . Type . UnwrapIfNullable ( ) ;
493- if ( ( operandType . IsPrimitive || operandType == typeof ( decimal ) ) &&
494- ( expression . Type . IsPrimitive || expression . Type == typeof ( decimal ) ) &&
495- expression . Type != operandType )
496- {
497- return _hqlTreeBuilder . Cast ( VisitExpression ( expression . Operand ) . AsExpression ( ) , expression . Type ) ;
498- }
499-
500- return VisitExpression ( expression . Operand ) ;
488+ return IsCastRequired ( expression . Operand , expression . Type )
489+ ? _hqlTreeBuilder . Cast ( VisitExpression ( expression . Operand ) . AsExpression ( ) , expression . Type )
490+ : VisitExpression ( expression . Operand ) ;
501491 }
502492
503493 throw new NotSupportedException ( expression . ToString ( ) ) ;
@@ -598,5 +588,96 @@ protected HqlTreeNode VisitNewArrayExpression(NewArrayExpression expression)
598588 var expressionSubTree = expression . Expressions . Select ( exp => VisitExpression ( exp ) ) . ToArray ( ) ;
599589 return _hqlTreeBuilder . ExpressionSubTreeHolder ( expressionSubTree ) ;
600590 }
591+
592+ private bool IsCastRequired ( Expression expression , System . Type toType )
593+ {
594+ return toType != typeof ( object ) && IsCastRequired ( GetType ( expression ) , TypeFactory . GetDefaultTypeFor ( toType ) ) ;
595+ }
596+
597+ private bool IsCastRequired ( IType type , IType toType )
598+ {
599+ // A type can be null when casting an entity into a base class, in that case we should not cast
600+ if ( type == null || toType == null || Equals ( type , toType ) )
601+ {
602+ return false ;
603+ }
604+
605+ var sqlTypes = type . SqlTypes ( _parameters . SessionFactory ) ;
606+ var toSqlTypes = toType . SqlTypes ( _parameters . SessionFactory ) ;
607+ if ( sqlTypes . Length != 1 || toSqlTypes . Length != 1 )
608+ {
609+ return false ; // Casting a multi-column type is not possible
610+ }
611+
612+ if ( type . ReturnedClass . IsEnum && sqlTypes [ 0 ] . DbType == DbType . String )
613+ {
614+ return false ; // Never cast an enum that is mapped as string, the type will provide a string for the parameter value
615+ }
616+
617+ return sqlTypes [ 0 ] . DbType != toSqlTypes [ 0 ] . DbType ;
618+ }
619+
620+ private bool IsCastRequired ( System . Type type , string sqlFunctionName )
621+ {
622+ if ( type == typeof ( object ) )
623+ {
624+ return false ;
625+ }
626+
627+ var toType = TypeFactory . GetDefaultTypeFor ( type ) ;
628+ if ( toType == null )
629+ {
630+ return true ; // Fallback to the old behavior
631+ }
632+
633+ var sqlFunction = _parameters . SessionFactory . SQLFunctionRegistry . FindSQLFunction ( sqlFunctionName ) ;
634+ if ( sqlFunction == null )
635+ {
636+ return true ; // Fallback to the old behavior
637+ }
638+
639+ var fnReturnType = sqlFunction . ReturnType ( toType , _parameters . SessionFactory ) ;
640+ return fnReturnType == null || IsCastRequired ( fnReturnType , toType ) ;
641+ }
642+
643+ private IType GetType ( Expression expression )
644+ {
645+ if ( ! ( expression is MemberExpression memberExpression ) )
646+ {
647+ return expression . Type != typeof ( object )
648+ ? TypeFactory . GetDefaultTypeFor ( expression . Type )
649+ : null ;
650+ }
651+
652+ // Try to get the mapped type for the member as it may be a non default one
653+ var entityName = TryGetEntityName ( memberExpression ) ;
654+ if ( entityName == null )
655+ {
656+ return TypeFactory . GetDefaultTypeFor ( expression . Type ) ; // Not mapped
657+ }
658+
659+ var persister = _parameters . SessionFactory . GetEntityPersister ( entityName ) ;
660+ var index = persister . EntityMetamodel . GetPropertyIndexOrNull ( memberExpression . Member . Name ) ;
661+ return ! index . HasValue
662+ ? TypeFactory . GetDefaultTypeFor ( expression . Type ) // Not mapped
663+ : persister . EntityMetamodel . PropertyTypes [ index . Value ] ;
664+ }
665+
666+ private string TryGetEntityName ( MemberExpression memberExpression )
667+ {
668+ System . Type entityType ;
669+ // Try to get the actual entity type from the query source if possbile as member can be declared
670+ // in a base type
671+ if ( memberExpression . Expression is QuerySourceReferenceExpression querySourceReferenceExpression )
672+ {
673+ entityType = querySourceReferenceExpression . Type ;
674+ }
675+ else
676+ {
677+ entityType = memberExpression . Member . ReflectedType ;
678+ }
679+
680+ return _parameters . SessionFactory . TryGetGuessEntityName ( entityType ) ;
681+ }
601682 }
602683}
0 commit comments