@@ -935,17 +935,17 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
935935 Def ("vfloatn" , ["vfloatn" , "float" ], invoke_name = "fmax_common" , convert_args = [(1 ,0 )]),
936936 Def ("vdoublen" , ["vdoublen" , "double" ], invoke_name = "fmax_common" , convert_args = [(1 ,0 )]),
937937 Def ("vhalfn" , ["vhalfn" , "half" ], invoke_name = "fmax_common" , convert_args = [(1 ,0 )]), # Non-standard. Deprecated.
938- Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_name = "s_max" , marray_use_loop = True ),
939- Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_name = "u_max" , marray_use_loop = True ),
938+ Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_name = "s_max" , marray_use_loop = True , template_scalar_args = True ),
939+ Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_name = "u_max" , marray_use_loop = True , template_scalar_args = True ),
940940 Def ("vigeninteger" , ["vigeninteger" , "elementtype0" ], invoke_name = "s_max" ),
941941 Def ("vugeninteger" , ["vugeninteger" , "elementtype0" ], invoke_name = "u_max" ),
942942 Def ("mgentype" , ["mgentype" , "elementtype0" ], marray_use_loop = True )],
943943 "(min)" : [Def ("genfloat" , ["genfloat" , "genfloat" ], invoke_name = "fmin_common" , template_scalar_args = True ),
944944 Def ("vfloatn" , ["vfloatn" , "float" ], invoke_name = "fmin_common" , convert_args = [(1 ,0 )]),
945945 Def ("vdoublen" , ["vdoublen" , "double" ], invoke_name = "fmin_common" , convert_args = [(1 ,0 )]),
946946 Def ("vhalfn" , ["vhalfn" , "half" ], invoke_name = "fmin_common" , convert_args = [(1 ,0 )]), # Non-standard. Deprecated.
947- Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_name = "s_min" , marray_use_loop = True ),
948- Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_name = "u_min" , marray_use_loop = True ),
947+ Def ("igeninteger" , ["igeninteger" , "igeninteger" ], invoke_name = "s_min" , marray_use_loop = True , template_scalar_args = True ),
948+ Def ("ugeninteger" , ["ugeninteger" , "ugeninteger" ], invoke_name = "u_min" , marray_use_loop = True , template_scalar_args = True ),
949949 Def ("vigeninteger" , ["vigeninteger" , "elementtype0" ], invoke_name = "s_min" ),
950950 Def ("vugeninteger" , ["vugeninteger" , "elementtype0" ], invoke_name = "u_min" ),
951951 Def ("mgentype" , ["mgentype" , "elementtype0" ], marray_use_loop = True )],
@@ -957,7 +957,7 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
957957 Def ("mdoublen" , ["mdoublen" , "mdoublen" , "double" ]),
958958 Def ("mhalfn" , ["mhalfn" , "mhalfn" , "half" ])], # Non-standard. Deprecated.
959959 "radians" : [Def ("genfloat" , ["genfloat" ], template_scalar_args = True )],
960- "step" : [Def ("genfloat" , ["genfloat" , "genfloat" ]),
960+ "step" : [Def ("genfloat" , ["genfloat" , "genfloat" ], template_scalar_args = True ),
961961 Def ("vfloatn" , ["float" , "vfloatn" ], convert_args = [(0 ,1 )]),
962962 Def ("vdoublen" , ["double" , "vdoublen" ], convert_args = [(0 ,1 )]),
963963 Def ("vhalfn" , ["half" , "vhalfn" ], convert_args = [(0 ,1 )]), # Non-standard. Deprecated.
@@ -989,25 +989,25 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
989989 Def ("float" , ["mgeofloat" , "mgeofloat" ], invoke_name = "Dot" ),
990990 Def ("double" , ["mgeodouble" , "mgeodouble" ], invoke_name = "Dot" ),
991991 Def ("half" , ["mgeohalf" , "mgeohalf" ], invoke_name = "Dot" ),
992- Def ("sgenfloat" , ["sgenfloat" , "sgenfloat" ], custom_invoke = (lambda return_types , arg_types , arg_names : ' return ' + ' * ' .join (arg_names ) + ';' ))],
993- "distance" : [Def ("float" , ["gengeofloat" , "gengeofloat" ]),
994- Def ("double" , ["gengeodouble" , "gengeodouble" ]),
995- Def ("half" , ["gengeohalf" , "gengeohalf" ])],
996- "length" : [Def ("float" , ["gengeofloat" ]),
997- Def ("double" , ["gengeodouble" ]),
998- Def ("half" , ["gengeohalf" ])],
999- "normalize" : [Def ("gengeofloat" , ["gengeofloat" ]),
1000- Def ("gengeodouble" , ["gengeodouble" ]),
1001- Def ("gengeohalf" , ["gengeohalf" ])],
1002- "fast_distance" : [Def ("float" , ["gengeofloat" , "gengeofloat" ]),
1003- Def ("double" , ["gengeodouble" , "gengeodouble" ]),
1004- Def ("half" , ["gengeohalf" , "gengeohalf" ])],
1005- "fast_length" : [Def ("float" , ["gengeofloat" ]),
1006- Def ("double" , ["gengeodouble" ]),
1007- Def ("half" , ["gengeohalf" ])],
1008- "fast_normalize" : [Def ("gengeofloat" , ["gengeofloat" ]),
1009- Def ("gengeodouble" , ["gengeodouble" ]),
1010- Def ("gengeohalf" , ["gengeohalf" ])],
992+ Def ("sgenfloat" , ["sgenfloat" , "sgenfloat" ], template_scalar_args = True , custom_invoke = (lambda return_types , arg_types , arg_names : ' return ' + ' * ' .join (arg_names ) + ';' ))],
993+ "distance" : [Def ("float" , ["gengeofloat" , "gengeofloat" ], template_scalar_args = True ),
994+ Def ("double" , ["gengeodouble" , "gengeodouble" ], template_scalar_args = True ),
995+ Def ("half" , ["gengeohalf" , "gengeohalf" ], template_scalar_args = True )],
996+ "length" : [Def ("float" , ["gengeofloat" ], template_scalar_args = True ),
997+ Def ("double" , ["gengeodouble" ], template_scalar_args = True ),
998+ Def ("half" , ["gengeohalf" ], template_scalar_args = True )],
999+ "normalize" : [Def ("gengeofloat" , ["gengeofloat" ], template_scalar_args = True ),
1000+ Def ("gengeodouble" , ["gengeodouble" ], template_scalar_args = True ),
1001+ Def ("gengeohalf" , ["gengeohalf" ], template_scalar_args = True )],
1002+ "fast_distance" : [Def ("float" , ["gengeofloat" , "gengeofloat" ], template_scalar_args = True ),
1003+ Def ("double" , ["gengeodouble" , "gengeodouble" ], template_scalar_args = True ),
1004+ Def ("half" , ["gengeohalf" , "gengeohalf" ], template_scalar_args = True )],
1005+ "fast_length" : [Def ("float" , ["gengeofloat" ], template_scalar_args = True ),
1006+ Def ("double" , ["gengeodouble" ], template_scalar_args = True ),
1007+ Def ("half" , ["gengeohalf" ], template_scalar_args = True )],
1008+ "fast_normalize" : [Def ("gengeofloat" , ["gengeofloat" ], template_scalar_args = True ),
1009+ Def ("gengeodouble" , ["gengeodouble" ], template_scalar_args = True ),
1010+ Def ("gengeohalf" , ["gengeohalf" ], template_scalar_args = True )],
10111011 # Relational functions
10121012 "isequal" : [RelDef ("samesizesignedint0" , ["vgenfloat" , "vgenfloat" ], invoke_name = "FOrdEqual" ),
10131013 RelDef ("bool" , ["sgenfloat" , "sgenfloat" ], invoke_name = "FOrdEqual" ),
@@ -1052,13 +1052,13 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
10521052 RelDef ("bool" , ["sgenfloat" ], invoke_name = "SignBitSet" ),
10531053 RelDef ("boolelements0" , ["mgenfloat" ])],
10541054 "any" : [Def ("int" , ["vigeninteger" ], custom_invoke = get_custom_any_all_vec_invoke ("Any" )),
1055- Def ("bool" , ["sigeninteger" ], custom_invoke = (lambda return_type , arg_types , arg_names : f' return bool(int(detail::msbIsSet({ arg_names [0 ]} )));' )),
1055+ Def ("bool" , ["sigeninteger" ], template_scalar_args = True , custom_invoke = (lambda return_type , arg_types , arg_names : f' return bool(int(detail::msbIsSet({ arg_names [0 ]} )));' )),
10561056 Def ("bool" , ["migeninteger" ], custom_invoke = get_custom_any_all_marray_invoke ("any" ))],
10571057 "all" : [Def ("int" , ["vigeninteger" ], custom_invoke = get_custom_any_all_vec_invoke ("All" )),
1058- Def ("bool" , ["sigeninteger" ], custom_invoke = (lambda return_type , arg_types , arg_names : f' return bool(int(detail::msbIsSet({ arg_names [0 ]} )));' )),
1058+ Def ("bool" , ["sigeninteger" ], template_scalar_args = True , custom_invoke = (lambda return_type , arg_types , arg_names : f' return bool(int(detail::msbIsSet({ arg_names [0 ]} )));' )),
10591059 Def ("bool" , ["migeninteger" ], custom_invoke = get_custom_any_all_marray_invoke ("all" ))],
10601060 "bitselect" : [Def ("vgentype" , ["vgentype" , "vgentype" , "vgentype" ]),
1061- Def ("sgentype" , ["sgentype" , "sgentype" , "sgentype" ]),
1061+ Def ("sgentype" , ["sgentype" , "sgentype" , "sgentype" ], template_scalar_args = True ),
10621062 Def ("mgentype" , ["mgentype" , "mgentype" , "mgentype" ], marray_use_loop = True )],
10631063 "select" : [Def ("vint8n" , ["vint8n" , "vint8n" , "vint8n" ]),
10641064 Def ("vint16n" , ["vint16n" , "vint16n" , "vint16n" ]),
@@ -1082,7 +1082,7 @@ def custom_nan_invoke(return_type, arg_types, arg_names):
10821082 Def ("vfloatn" , ["vfloatn" , "vfloatn" , "vuint32n" ]),
10831083 Def ("vdoublen" , ["vdoublen" , "vdoublen" , "vuint64n" ]),
10841084 Def ("vhalfn" , ["vhalfn" , "vhalfn" , "vuint16n" ]),
1085- Def ("sgentype" , ["sgentype" , "sgentype" , "bool" ], custom_invoke = custom_bool_select_invoke ),
1085+ Def ("sgentype" , ["sgentype" , "sgentype" , "bool" ], template_scalar_args = True , custom_invoke = custom_bool_select_invoke ),
10861086 Def ("mgentype" , ["mgentype" , "mgentype" , "mbooln" ], marray_use_loop = True )]}
10871087# List of all builtin definitions in the sycl::native namespace.
10881088native_builtins = {"cos" : [Def ("genfloatf" , ["genfloatf" ], invoke_prefix = "native_" )],
@@ -1210,10 +1210,15 @@ def type_combinations(return_type, arg_types, template_scalars):
12101210 Generates all return and argument type combinations for a given builtin
12111211 definition.
12121212 """
1213- unique_types = list (dict .fromkeys (arg_types + [ return_type ] ))
1213+ unique_types = list (dict .fromkeys (arg_types ))
12141214 unique_type_lists = [builtin_types [unique_type ] for unique_type in unique_types ]
12151215 if template_scalars :
12161216 unique_type_lists = [convert_scalars_to_templated (type_list ) for type_list in unique_type_lists ]
1217+ if return_type not in unique_types :
1218+ # Add return type after scalars have been turned to template arguments if
1219+ # it is unique, to avoid undeducible return types.
1220+ unique_types .append (return_type )
1221+ unique_type_lists .append (builtin_types [return_type ])
12171222 combinations = list (itertools .product (* unique_type_lists ))
12181223 result = []
12191224 for combination in combinations :
0 commit comments