22
33import com .google .common .collect .ArrayListMultimap ;
44import com .google .common .collect .ImmutableList ;
5- import com .google .common .collect .ImmutableMap ;
5+ import com .google .common .collect .ImmutableListMultimap ;
6+ import com .google .common .collect .ListMultimap ;
67import com .google .common .collect .Multimap ;
78import com .google .common .collect .Multimaps ;
89import com .google .common .collect .Streams ;
@@ -155,7 +156,7 @@ protected class FunctionFinder {
155156 private final String substraitName ;
156157 private final SqlOperator operator ;
157158 private final List <F > functions ;
158- private final Map <String , F > directMap ;
159+ private final ListMultimap <String , F > directMap ;
159160 private final Optional <SingularArgumentMatcher <F >> singularInputType ;
160161 private final Util .IntRange argRange ;
161162
@@ -168,7 +169,7 @@ public FunctionFinder(String substraitName, SqlOperator operator, List<F> functi
168169 functions .stream ().mapToInt (t -> t .getRange ().getStartInclusive ()).min ().getAsInt (),
169170 functions .stream ().mapToInt (t -> t .getRange ().getEndExclusive ()).max ().getAsInt ());
170171 this .singularInputType = getSingularInputType (functions );
171- ImmutableMap .Builder <String , F > directMap = ImmutableMap .builder ();
172+ ImmutableListMultimap .Builder <String , F > directMap = ImmutableListMultimap .builder ();
172173 for (F func : functions ) {
173174 String key = func .key ();
174175 directMap .put (key , func );
@@ -349,6 +350,9 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
349350 * Not enough context here to construct a substrait EnumArg.
350351 * Once a FunctionVariant is resolved we can map the String Literal
351352 * to a EnumArg.
353+ *
354+ * Note that if there are multiple registered function extensions which can match a particular Call,
355+ * the last one added to the extension collection will be matched.
352356 */
353357 List <RexNode > operandsList = call .getOperands ().collect (Collectors .toList ());
354358 List <Expression > operands =
@@ -369,7 +373,13 @@ public Optional<T> attemptMatch(C call, Function<RexNode, Expression> topLevelCo
369373 .findFirst ();
370374
371375 if (directMatchKey .isPresent ()) {
372- F variant = directMap .get (directMatchKey .get ());
376+ List <F > variants = directMap .get (directMatchKey .get ());
377+ if (variants .isEmpty ()) {
378+
379+ return Optional .empty ();
380+ }
381+
382+ F variant = variants .get (variants .size () - 1 );
373383 variant .validateOutputType (operands , outputType );
374384 List <FunctionArg > funcArgs =
375385 IntStream .range (0 , operandsList .size ())
0 commit comments