Skip to content

Commit 55bdf2e

Browse files
committed
Support multiple annotation based definitions of scalar function
For datatypes with multiple binary representations we need to define same function multiple times depending on actal types of parameters. E.g. add (Slice a, Slice b) and add (long a, long b) represent + operator for DECIMAL type. First version for longer DECIMAL values (represented internally as Slice) and second, faster version for shorter decimals (represented internally as long). This patch allows multiple versions of same scalar udf/operator to be defined using annotation syntax. Functions with same name are grouped together and exposed as single ParametricFunction which internally does dispatching to proper MethodHandle based on actual parameter types at runtime.
1 parent d9252cb commit 55bdf2e

File tree

4 files changed

+238
-53
lines changed

4 files changed

+238
-53
lines changed

presto-main/src/main/java/com/facebook/presto/metadata/FunctionListBuilder.java

Lines changed: 114 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
package com.facebook.presto.metadata;
1515

1616
import com.facebook.presto.operator.Description;
17-
import com.facebook.presto.operator.aggregation.GenericAggregationFunctionFactory;
1817
import com.facebook.presto.operator.aggregation.InternalAggregationFunction;
1918
import com.facebook.presto.operator.scalar.JsonPath;
2019
import com.facebook.presto.operator.scalar.ScalarFunction;
@@ -46,19 +45,26 @@
4645
import java.util.ArrayList;
4746
import java.util.Arrays;
4847
import java.util.List;
48+
import java.util.Map;
49+
import java.util.Objects;
50+
import java.util.Optional;
4951
import java.util.Set;
52+
import java.util.stream.Collectors;
5053

5154
import static com.facebook.presto.metadata.FunctionKind.SCALAR;
5255
import static com.facebook.presto.metadata.FunctionKind.WINDOW;
5356
import static com.facebook.presto.metadata.Signature.typeParameter;
57+
import static com.facebook.presto.operator.aggregation.GenericAggregationFunctionFactory.fromAggregationDefinition;
5458
import static com.facebook.presto.spi.type.BigintType.BIGINT;
5559
import static com.facebook.presto.spi.type.TypeSignature.parseTypeSignature;
5660
import static com.google.common.base.CaseFormat.LOWER_CAMEL;
5761
import static com.google.common.base.CaseFormat.LOWER_UNDERSCORE;
5862
import static com.google.common.base.Preconditions.checkArgument;
63+
import static com.google.common.collect.Iterables.getOnlyElement;
5964
import static java.lang.invoke.MethodHandles.lookup;
6065
import static java.util.Locale.ENGLISH;
6166
import static java.util.Objects.requireNonNull;
67+
import static java.util.stream.Collectors.toList;
6268

6369
public class FunctionListBuilder
6470
{
@@ -69,7 +75,35 @@ public class FunctionListBuilder
6975
Regex.class,
7076
JsonPath.class);
7177

72-
private final List<SqlFunction> functions = new ArrayList<>();
78+
private static final class FunctionWithMethodHandle
79+
{
80+
private final SqlFunction function;
81+
private final Optional<MethodHandle> methodHandle;
82+
83+
public FunctionWithMethodHandle(SqlFunction function)
84+
{
85+
this.function = function;
86+
this.methodHandle = Optional.empty();
87+
}
88+
89+
public FunctionWithMethodHandle(SqlFunction function, MethodHandle methodHandle)
90+
{
91+
this.function = function;
92+
this.methodHandle = Optional.of(methodHandle);
93+
}
94+
95+
public Optional<MethodHandle> getMethodHandle()
96+
{
97+
return methodHandle;
98+
}
99+
100+
public SqlFunction getFunction()
101+
{
102+
return function;
103+
}
104+
}
105+
106+
private final List<FunctionWithMethodHandle> functions = new ArrayList<>();
73107
private final TypeManager typeManager;
74108

75109
public FunctionListBuilder(TypeManager typeManager)
@@ -83,14 +117,14 @@ public FunctionListBuilder window(String name, Type returnType, List<? extends T
83117
new Signature(name, WINDOW, returnType.getTypeSignature(), Lists.transform(ImmutableList.copyOf(argumentTypes), Type::getTypeSignature)),
84118
functionClass);
85119

86-
functions.add(new SqlWindowFunction(windowFunctionSupplier));
120+
addFunction(new SqlWindowFunction(windowFunctionSupplier));
87121
return this;
88122
}
89123

90124
public FunctionListBuilder window(String name, Class<? extends ValueWindowFunction> clazz, String typeVariable, String... argumentTypes)
91125
{
92126
Signature signature = new Signature(name, WINDOW, ImmutableList.of(typeParameter(typeVariable)), typeVariable, ImmutableList.copyOf(argumentTypes), false);
93-
functions.add(new SqlWindowFunction(new ReflectionWindowFunctionSupplier<>(signature, clazz)));
127+
addFunction(new SqlWindowFunction(new ReflectionWindowFunctionSupplier<>(signature, clazz)));
94128
return this;
95129
}
96130

@@ -108,29 +142,47 @@ public FunctionListBuilder aggregate(InternalAggregationFunction function)
108142
name = name.toLowerCase(ENGLISH);
109143

110144
String description = getDescription(function.getClass());
111-
functions.add(SqlAggregationFunction.create(name, description, function));
145+
addFunction(SqlAggregationFunction.create(name, description, function));
112146
return this;
113147
}
114148

115149
public FunctionListBuilder aggregate(Class<?> aggregationDefinition)
116150
{
117-
functions.addAll(GenericAggregationFunctionFactory.fromAggregationDefinition(aggregationDefinition, typeManager).listFunctions());
151+
fromAggregationDefinition(aggregationDefinition, typeManager).listFunctions().forEach(this::addFunction);
118152
return this;
119153
}
120154

121155
public FunctionListBuilder scalar(Signature signature, MethodHandle function, boolean deterministic, String description, boolean hidden, boolean nullable, List<Boolean> nullableArguments)
122156
{
123-
functions.add(SqlScalarFunction.create(signature, description, hidden, function, deterministic, nullable, nullableArguments));
157+
addFunction(SqlScalarFunction.create(signature, description, hidden, function, deterministic, nullable, nullableArguments), function);
124158
return this;
125159
}
126160

127161
private FunctionListBuilder operator(OperatorType operatorType, TypeSignature returnType, List<TypeSignature> parameterTypes, MethodHandle function, boolean nullable, List<Boolean> nullableArguments)
128162
{
129-
functions.add(SqlOperator.create(operatorType, parameterTypes, returnType, function, nullable, nullableArguments));
163+
addFunction(SqlOperator.create(operatorType, parameterTypes, returnType, function, nullable, nullableArguments), function);
130164
return this;
131165
}
132166

133167
public FunctionListBuilder scalar(Class<?> clazz)
168+
{
169+
FunctionListBuilder localFunctionListBuilder = new FunctionListBuilder(typeManager);
170+
localFunctionListBuilder.processScalarsInClass(clazz);
171+
functions.addAll(localFunctionListBuilder.functions);
172+
return this;
173+
}
174+
175+
private void addFunction(SqlFunction function)
176+
{
177+
functions.add(new FunctionWithMethodHandle(function));
178+
}
179+
180+
private void addFunction(SqlFunction function, MethodHandle methodHandle)
181+
{
182+
functions.add(new FunctionWithMethodHandle(function, methodHandle));
183+
}
184+
185+
private FunctionListBuilder processScalarsInClass(Class<?> clazz)
134186
{
135187
try {
136188
boolean foundOne = false;
@@ -143,9 +195,61 @@ public FunctionListBuilder scalar(Class<?> clazz)
143195
catch (IllegalAccessException e) {
144196
throw Throwables.propagate(e);
145197
}
198+
groupMatchingScalars();
146199
return this;
147200
}
148201

202+
private FunctionListBuilder groupMatchingScalars()
203+
{
204+
List<FunctionWithMethodHandle> newFunctions = groupMatchingScalars(functions);
205+
functions.clear();
206+
functions.addAll(newFunctions);
207+
return this;
208+
}
209+
210+
private List<FunctionWithMethodHandle> groupMatchingScalars(List<FunctionWithMethodHandle> inputFunctions)
211+
{
212+
inputFunctions.forEach(f -> checkArgument(f.getMethodHandle().isPresent(), "expected function with method handle: %s", f));
213+
inputFunctions.forEach(f -> checkArgument(f.getFunction() instanceof SqlScalarFunction, "expected scalar function: %s", f));
214+
Map<Signature, List<FunctionWithMethodHandle>> groupedFunctions = inputFunctions.stream()
215+
.collect(Collectors.groupingBy(f -> f.getFunction().getSignature()));
216+
List<FunctionWithMethodHandle> resultFunctions = new ArrayList<>();
217+
for (Map.Entry<Signature, List<FunctionWithMethodHandle>> entry : groupedFunctions.entrySet()) {
218+
List<FunctionWithMethodHandle> functionsGroup = entry.getValue();
219+
if (functionsGroup.size() == 1) {
220+
resultFunctions.add(getOnlyElement(functionsGroup));
221+
}
222+
else {
223+
resultFunctions.add(new FunctionWithMethodHandle(buildGroupingScalarWrapper(functionsGroup)));
224+
}
225+
}
226+
return resultFunctions;
227+
}
228+
229+
private SqlScalarFunction buildGroupingScalarWrapper(List<FunctionWithMethodHandle> functionsGroup)
230+
{
231+
checkArgument(functionsGroup.size() > 1, "functions group must have multiple elements");
232+
SqlFunction masterFunction = functionsGroup.get(0).getFunction();
233+
functionsGroup.forEach(f -> checkMatchingScalarsConsistent(masterFunction, f.getFunction()));
234+
235+
SqlScalarFunctionBuilder wrapperBuilder = SqlScalarFunction.builder()
236+
.signature(masterFunction.getSignature())
237+
.deterministic(masterFunction.isDeterministic())
238+
.hidden(masterFunction.isHidden())
239+
.description(masterFunction.getDescription());
240+
241+
functionsGroup.forEach(f -> wrapperBuilder.method((SqlScalarFunction) f.getFunction(), f.getMethodHandle().get()));
242+
return wrapperBuilder.build();
243+
}
244+
245+
private void checkMatchingScalarsConsistent(SqlFunction master, SqlFunction canditate)
246+
{
247+
checkArgument(canditate.getSignature().equals(master.getSignature()), "signature mismatch; %s vs. %s", master, canditate);
248+
checkArgument(canditate.isHidden() == master.isHidden(), "hidden flag mismatch; %s vs. %s", master, canditate);
249+
checkArgument(canditate.isDeterministic() == master.isDeterministic(), "deterministic flag mismatch; %s vs. %s", master, canditate);
250+
checkArgument(Objects.equals(canditate.getDescription(), master.getDescription()), "description mismatch, %s vs. %s", master, canditate);
251+
}
252+
149253
public FunctionListBuilder functions(SqlFunction... sqlFunctions)
150254
{
151255
for (SqlFunction sqlFunction : sqlFunctions) {
@@ -157,7 +261,7 @@ public FunctionListBuilder functions(SqlFunction... sqlFunctions)
157261
public FunctionListBuilder function(SqlFunction sqlFunction)
158262
{
159263
requireNonNull(sqlFunction, "parametricFunction is null");
160-
functions.add(sqlFunction);
264+
addFunction(sqlFunction);
161265
return this;
162266
}
163267

@@ -349,6 +453,6 @@ private static List<Class<?>> getParameterTypes(Class<?>... types)
349453

350454
public List<SqlFunction> getFunctions()
351455
{
352-
return ImmutableList.copyOf(functions);
456+
return ImmutableList.copyOf(functions.stream().map(FunctionWithMethodHandle::getFunction).collect(toList()));
353457
}
354458
}

presto-main/src/main/java/com/facebook/presto/metadata/PolymorphicScalarFunction.java

Lines changed: 22 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,15 +15,15 @@
1515

1616
import com.facebook.presto.metadata.SqlScalarFunctionBuilder.MethodsGroup;
1717
import com.facebook.presto.metadata.SqlScalarFunctionBuilder.SpecializeContext;
18+
import com.facebook.presto.metadata.SqlScalarFunctionBuilder.TargetMethodDelegate;
1819
import com.facebook.presto.operator.scalar.ScalarFunctionImplementation;
1920
import com.facebook.presto.spi.type.Type;
2021
import com.facebook.presto.spi.type.TypeManager;
2122
import com.facebook.presto.spi.type.TypeSignature;
22-
import com.facebook.presto.util.Reflection;
2323

2424
import java.lang.invoke.MethodHandle;
2525
import java.lang.invoke.MethodHandles;
26-
import java.lang.reflect.Method;
26+
import java.lang.invoke.MethodType;
2727
import java.util.List;
2828
import java.util.Map;
2929
import java.util.Optional;
@@ -90,10 +90,10 @@ public ScalarFunctionImplementation specialize(Map<String, Type> types, List<Typ
9090
Type resolvedReturnType = resolveReturnType(types, typeManager, calculatedReturnType);
9191
SpecializeContext context = new SpecializeContext(types, filterPresentLiterals(literalParameters), resolvedParameterTypes, resolvedReturnType, typeManager, functionRegistry);
9292

93-
Optional<Method> matchingMethod = Optional.empty();
93+
Optional<TargetMethodDelegate> matchingMethod = Optional.empty();
9494
Optional<MethodsGroup> matchingMethodsGroup = Optional.empty();
9595
for (MethodsGroup candidateMethodsGroup : methodsGroups) {
96-
for (Method candidateMethod : candidateMethodsGroup.getMethods()) {
96+
for (TargetMethodDelegate candidateMethod : candidateMethodsGroup.getMethods()) {
9797
if (matchesParameterAndReturnTypes(candidateMethod, resolvedParameterTypes, resolvedReturnType) &&
9898
predicateIsTrue(candidateMethodsGroup, context)) {
9999
if (matchingMethod.isPresent()) {
@@ -111,10 +111,14 @@ public ScalarFunctionImplementation specialize(Map<String, Type> types, List<Typ
111111
}
112112
checkState(matchingMethod.isPresent(), "no matching method for parameter types %s", parameterTypes);
113113

114-
List<Object> extraParameters = computeExtraParameters(matchingMethodsGroup.get(), context);
115-
MethodHandle matchingMethodHandle = applyExtraParameters(matchingMethod.get(), extraParameters);
116-
117-
return new ScalarFunctionImplementation(nullableResult, nullableArguments, matchingMethodHandle, deterministic);
114+
if (matchingMethod.get().getScalarFunctionDelegate().isPresent()) {
115+
return matchingMethod.get().getScalarFunctionDelegate().get().specialize(types, parameterTypes, typeManager, functionRegistry);
116+
}
117+
else {
118+
List<Object> extraParameters = computeExtraParameters(matchingMethodsGroup.get(), context);
119+
MethodHandle matchingMethodHandle = applyExtraParameters(matchingMethod.get(), extraParameters);
120+
return new ScalarFunctionImplementation(nullableResult, nullableArguments, matchingMethodHandle, deterministic);
121+
}
118122
}
119123

120124
private Type resolveReturnType(Map<String, Type> types, TypeManager typeManager, TypeSignature calculatedReturnType)
@@ -129,19 +133,19 @@ private Type resolveReturnType(Map<String, Type> types, TypeManager typeManager,
129133
return resolvedReturnType;
130134
}
131135

132-
private boolean matchesParameterAndReturnTypes(Method method, List<Type> resolvedTypes, Type returnType)
136+
private boolean matchesParameterAndReturnTypes(TargetMethodDelegate method, List<Type> resolvedTypes, Type returnType)
133137
{
134-
checkState(method.getParameterCount() >= resolvedTypes.size(),
135-
"method %s has not enough arguments: %s (should have at least %s)", method.getName(), method.getParameterCount(), resolvedTypes.size());
138+
MethodType methodHandleType = method.getMethodHandle().type();
139+
checkState(methodHandleType.parameterCount() >= resolvedTypes.size(),
140+
"method %s has not enough arguments: %s (should have at least %s)", method.getName(), methodHandleType.parameterCount(), resolvedTypes.size());
136141

137-
Class<?>[] methodParameterJavaTypes = method.getParameterTypes();
142+
List<Class<?>> methodParameterJavaTypes = methodHandleType.parameterList();
138143
for (int i = 0; i < resolvedTypes.size(); ++i) {
139-
if (!methodParameterJavaTypes[i].equals(resolvedTypes.get(i).getJavaType())) {
144+
if (!methodParameterJavaTypes.get(i).equals(resolvedTypes.get(i).getJavaType())) {
140145
return false;
141146
}
142147
}
143-
144-
return method.getReturnType().equals(returnType.getJavaType());
148+
return methodHandleType.returnType().equals(returnType.getJavaType());
145149
}
146150

147151
private boolean onlyFirstMatchedMethodHasPredicate(MethodsGroup matchingMethodsGroup, MethodsGroup methodsGroup)
@@ -166,16 +170,15 @@ private Map<String, Long> filterPresentLiterals(Map<String, OptionalLong> boundL
166170
.collect(toMap(entry -> entry.getKey().toLowerCase(US), entry -> entry.getValue().getAsLong()));
167171
}
168172

169-
private MethodHandle applyExtraParameters(Method matchingMethod, List<Object> extraParameters)
173+
private MethodHandle applyExtraParameters(TargetMethodDelegate matchingMethod, List<Object> extraParameters)
170174
{
171175
Signature signature = getSignature();
172176
int expectedNumberOfArguments = signature.getArgumentTypes().size() + extraParameters.size();
173-
int matchingMethodParameterCount = matchingMethod.getParameterCount();
177+
int matchingMethodParameterCount = matchingMethod.getMethodHandle().type().parameterCount();
174178
checkState(matchingMethodParameterCount == expectedNumberOfArguments,
175179
"method %s has invalid number of arguments: %s (should have %s)", matchingMethod.getName(), matchingMethodParameterCount, expectedNumberOfArguments);
176180

177-
MethodHandle matchingMethodHandle = Reflection.methodHandle(matchingMethod);
178-
matchingMethodHandle = MethodHandles.insertArguments(matchingMethodHandle, signature.getArgumentTypes().size(), extraParameters.toArray());
181+
MethodHandle matchingMethodHandle = MethodHandles.insertArguments(matchingMethod.getMethodHandle(), signature.getArgumentTypes().size(), extraParameters.toArray());
179182
return matchingMethodHandle;
180183
}
181184
}

presto-main/src/main/java/com/facebook/presto/metadata/SqlScalarFunction.java

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,10 @@ public static SqlScalarFunctionBuilder builder(Class<?> clazz)
7878
return new SqlScalarFunctionBuilder(clazz);
7979
}
8080

81+
public static SqlScalarFunctionBuilder builder()
82+
{
83+
return new SqlScalarFunctionBuilder();
84+
}
8185

8286
private static class SimpleSqlScalarFunction
8387
extends SqlScalarFunction

0 commit comments

Comments
 (0)