1414package com .facebook .presto .metadata ;
1515
1616import com .facebook .presto .operator .Description ;
17- import com .facebook .presto .operator .aggregation .GenericAggregationFunctionFactory ;
1817import com .facebook .presto .operator .aggregation .InternalAggregationFunction ;
1918import com .facebook .presto .operator .scalar .JsonPath ;
2019import com .facebook .presto .operator .scalar .ScalarFunction ;
4645import java .util .ArrayList ;
4746import java .util .Arrays ;
4847import java .util .List ;
48+ import java .util .Map ;
49+ import java .util .Objects ;
50+ import java .util .Optional ;
4951import java .util .Set ;
52+ import java .util .stream .Collectors ;
5053
5154import static com .facebook .presto .metadata .FunctionKind .SCALAR ;
5255import static com .facebook .presto .metadata .FunctionKind .WINDOW ;
5356import static com .facebook .presto .metadata .Signature .typeParameter ;
57+ import static com .facebook .presto .operator .aggregation .GenericAggregationFunctionFactory .fromAggregationDefinition ;
5458import static com .facebook .presto .spi .type .BigintType .BIGINT ;
5559import static com .facebook .presto .spi .type .TypeSignature .parseTypeSignature ;
5660import static com .google .common .base .CaseFormat .LOWER_CAMEL ;
5761import static com .google .common .base .CaseFormat .LOWER_UNDERSCORE ;
5862import static com .google .common .base .Preconditions .checkArgument ;
63+ import static com .google .common .collect .Iterables .getOnlyElement ;
5964import static java .lang .invoke .MethodHandles .lookup ;
6065import static java .util .Locale .ENGLISH ;
6166import static java .util .Objects .requireNonNull ;
67+ import static java .util .stream .Collectors .toList ;
6268
6369public class FunctionListBuilder
6470{
@@ -69,7 +75,35 @@ public class FunctionListBuilder
6975 Regex .class ,
7076 JsonPath .class );
7177
72- private final List <SqlFunction > functions = new ArrayList <>();
78+ private static final class FunctionWithMethodHandle
79+ {
80+ private final SqlFunction function ;
81+ private final Optional <MethodHandle > methodHandle ;
82+
83+ public FunctionWithMethodHandle (SqlFunction function )
84+ {
85+ this .function = function ;
86+ this .methodHandle = Optional .empty ();
87+ }
88+
89+ public FunctionWithMethodHandle (SqlFunction function , MethodHandle methodHandle )
90+ {
91+ this .function = function ;
92+ this .methodHandle = Optional .of (methodHandle );
93+ }
94+
95+ public Optional <MethodHandle > getMethodHandle ()
96+ {
97+ return methodHandle ;
98+ }
99+
100+ public SqlFunction getFunction ()
101+ {
102+ return function ;
103+ }
104+ }
105+
106+ private final List <FunctionWithMethodHandle > functions = new ArrayList <>();
73107 private final TypeManager typeManager ;
74108
75109 public FunctionListBuilder (TypeManager typeManager )
@@ -83,14 +117,14 @@ public FunctionListBuilder window(String name, Type returnType, List<? extends T
83117 new Signature (name , WINDOW , returnType .getTypeSignature (), Lists .transform (ImmutableList .copyOf (argumentTypes ), Type ::getTypeSignature )),
84118 functionClass );
85119
86- functions . add (new SqlWindowFunction (windowFunctionSupplier ));
120+ addFunction (new SqlWindowFunction (windowFunctionSupplier ));
87121 return this ;
88122 }
89123
90124 public FunctionListBuilder window (String name , Class <? extends ValueWindowFunction > clazz , String typeVariable , String ... argumentTypes )
91125 {
92126 Signature signature = new Signature (name , WINDOW , ImmutableList .of (typeParameter (typeVariable )), typeVariable , ImmutableList .copyOf (argumentTypes ), false );
93- functions . add (new SqlWindowFunction (new ReflectionWindowFunctionSupplier <>(signature , clazz )));
127+ addFunction (new SqlWindowFunction (new ReflectionWindowFunctionSupplier <>(signature , clazz )));
94128 return this ;
95129 }
96130
@@ -108,29 +142,47 @@ public FunctionListBuilder aggregate(InternalAggregationFunction function)
108142 name = name .toLowerCase (ENGLISH );
109143
110144 String description = getDescription (function .getClass ());
111- functions . add (SqlAggregationFunction .create (name , description , function ));
145+ addFunction (SqlAggregationFunction .create (name , description , function ));
112146 return this ;
113147 }
114148
115149 public FunctionListBuilder aggregate (Class <?> aggregationDefinition )
116150 {
117- functions . addAll ( GenericAggregationFunctionFactory . fromAggregationDefinition (aggregationDefinition , typeManager ).listFunctions ());
151+ fromAggregationDefinition (aggregationDefinition , typeManager ).listFunctions (). forEach ( this :: addFunction );
118152 return this ;
119153 }
120154
121155 public FunctionListBuilder scalar (Signature signature , MethodHandle function , boolean deterministic , String description , boolean hidden , boolean nullable , List <Boolean > nullableArguments )
122156 {
123- functions . add (SqlScalarFunction .create (signature , description , hidden , function , deterministic , nullable , nullableArguments ));
157+ addFunction (SqlScalarFunction .create (signature , description , hidden , function , deterministic , nullable , nullableArguments ), function );
124158 return this ;
125159 }
126160
127161 private FunctionListBuilder operator (OperatorType operatorType , TypeSignature returnType , List <TypeSignature > parameterTypes , MethodHandle function , boolean nullable , List <Boolean > nullableArguments )
128162 {
129- functions . add (SqlOperator .create (operatorType , parameterTypes , returnType , function , nullable , nullableArguments ));
163+ addFunction (SqlOperator .create (operatorType , parameterTypes , returnType , function , nullable , nullableArguments ), function );
130164 return this ;
131165 }
132166
133167 public FunctionListBuilder scalar (Class <?> clazz )
168+ {
169+ FunctionListBuilder localFunctionListBuilder = new FunctionListBuilder (typeManager );
170+ localFunctionListBuilder .processScalarsInClass (clazz );
171+ functions .addAll (localFunctionListBuilder .functions );
172+ return this ;
173+ }
174+
175+ private void addFunction (SqlFunction function )
176+ {
177+ functions .add (new FunctionWithMethodHandle (function ));
178+ }
179+
180+ private void addFunction (SqlFunction function , MethodHandle methodHandle )
181+ {
182+ functions .add (new FunctionWithMethodHandle (function , methodHandle ));
183+ }
184+
185+ private FunctionListBuilder processScalarsInClass (Class <?> clazz )
134186 {
135187 try {
136188 boolean foundOne = false ;
@@ -143,9 +195,61 @@ public FunctionListBuilder scalar(Class<?> clazz)
143195 catch (IllegalAccessException e ) {
144196 throw Throwables .propagate (e );
145197 }
198+ groupMatchingScalars ();
146199 return this ;
147200 }
148201
202+ private FunctionListBuilder groupMatchingScalars ()
203+ {
204+ List <FunctionWithMethodHandle > newFunctions = groupMatchingScalars (functions );
205+ functions .clear ();
206+ functions .addAll (newFunctions );
207+ return this ;
208+ }
209+
210+ private List <FunctionWithMethodHandle > groupMatchingScalars (List <FunctionWithMethodHandle > inputFunctions )
211+ {
212+ inputFunctions .forEach (f -> checkArgument (f .getMethodHandle ().isPresent (), "expected function with method handle: %s" , f ));
213+ inputFunctions .forEach (f -> checkArgument (f .getFunction () instanceof SqlScalarFunction , "expected scalar function: %s" , f ));
214+ Map <Signature , List <FunctionWithMethodHandle >> groupedFunctions = inputFunctions .stream ()
215+ .collect (Collectors .groupingBy (f -> f .getFunction ().getSignature ()));
216+ List <FunctionWithMethodHandle > resultFunctions = new ArrayList <>();
217+ for (Map .Entry <Signature , List <FunctionWithMethodHandle >> entry : groupedFunctions .entrySet ()) {
218+ List <FunctionWithMethodHandle > functionsGroup = entry .getValue ();
219+ if (functionsGroup .size () == 1 ) {
220+ resultFunctions .add (getOnlyElement (functionsGroup ));
221+ }
222+ else {
223+ resultFunctions .add (new FunctionWithMethodHandle (buildGroupingScalarWrapper (functionsGroup )));
224+ }
225+ }
226+ return resultFunctions ;
227+ }
228+
229+ private SqlScalarFunction buildGroupingScalarWrapper (List <FunctionWithMethodHandle > functionsGroup )
230+ {
231+ checkArgument (functionsGroup .size () > 1 , "functions group must have multiple elements" );
232+ SqlFunction masterFunction = functionsGroup .get (0 ).getFunction ();
233+ functionsGroup .forEach (f -> checkMatchingScalarsConsistent (masterFunction , f .getFunction ()));
234+
235+ SqlScalarFunctionBuilder wrapperBuilder = SqlScalarFunction .builder ()
236+ .signature (masterFunction .getSignature ())
237+ .deterministic (masterFunction .isDeterministic ())
238+ .hidden (masterFunction .isHidden ())
239+ .description (masterFunction .getDescription ());
240+
241+ functionsGroup .forEach (f -> wrapperBuilder .method ((SqlScalarFunction ) f .getFunction (), f .getMethodHandle ().get ()));
242+ return wrapperBuilder .build ();
243+ }
244+
245+ private void checkMatchingScalarsConsistent (SqlFunction master , SqlFunction canditate )
246+ {
247+ checkArgument (canditate .getSignature ().equals (master .getSignature ()), "signature mismatch; %s vs. %s" , master , canditate );
248+ checkArgument (canditate .isHidden () == master .isHidden (), "hidden flag mismatch; %s vs. %s" , master , canditate );
249+ checkArgument (canditate .isDeterministic () == master .isDeterministic (), "deterministic flag mismatch; %s vs. %s" , master , canditate );
250+ checkArgument (Objects .equals (canditate .getDescription (), master .getDescription ()), "description mismatch, %s vs. %s" , master , canditate );
251+ }
252+
149253 public FunctionListBuilder functions (SqlFunction ... sqlFunctions )
150254 {
151255 for (SqlFunction sqlFunction : sqlFunctions ) {
@@ -157,7 +261,7 @@ public FunctionListBuilder functions(SqlFunction... sqlFunctions)
157261 public FunctionListBuilder function (SqlFunction sqlFunction )
158262 {
159263 requireNonNull (sqlFunction , "parametricFunction is null" );
160- functions . add (sqlFunction );
264+ addFunction (sqlFunction );
161265 return this ;
162266 }
163267
@@ -349,6 +453,6 @@ private static List<Class<?>> getParameterTypes(Class<?>... types)
349453
350454 public List <SqlFunction > getFunctions ()
351455 {
352- return ImmutableList .copyOf (functions );
456+ return ImmutableList .copyOf (functions . stream (). map ( FunctionWithMethodHandle :: getFunction ). collect ( toList ()) );
353457 }
354458}
0 commit comments