Skip to content

Commit d1a29b3

Browse files
committed
test: simplify tests and add special test for overlapping default ext
Added a test for ltrim to ensure there is no issue specifically with the default extension collection functions (which have some special handling in isthmus) and an introduced ltrim function.
1 parent 03b0f24 commit d1a29b3

File tree

2 files changed

+82
-28
lines changed

2 files changed

+82
-28
lines changed

isthmus/src/test/java/io/substrait/isthmus/DuplicateFunctionUrnTest.java

Lines changed: 82 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import org.apache.calcite.rex.RexCall;
1717
import org.apache.calcite.rex.RexNode;
1818
import org.apache.calcite.sql.fun.SqlStdOperatorTable;
19+
import org.apache.calcite.sql.fun.SqlTrimFunction.Flag;
1920
import org.junit.jupiter.api.Test;
2021

2122
/** Tests to reproduce #562 */
@@ -33,21 +34,22 @@ public class DuplicateFunctionUrnTest extends PlanTestBase {
3334
collection2 = SimpleExtension.load("urn2://functions", extensions2);
3435
collection = collection1.merge(collection2);
3536

36-
// Verify that the merged collection contains duplicate functions with different URNs
37+
// Verify that the merged collection contains duplicate concat functions with different URNs
3738
// This is a precondition for the tests - if this fails, the tests don't make sense
38-
List<SimpleExtension.ScalarFunctionVariant> ltrimFunctions =
39-
collection.scalarFunctions().stream().filter(f -> f.name().equals("ltrim")).toList();
39+
List<SimpleExtension.ScalarFunctionVariant> concatFunctions =
40+
collection.scalarFunctions().stream().filter(f -> f.name().equals("concat")).toList();
4041

41-
if (ltrimFunctions.size() != 2) {
42+
if (concatFunctions.size() != 2) {
4243
throw new IllegalStateException(
43-
"Expected 2 ltrim functions in merged collection, but found: " + ltrimFunctions.size());
44+
"Expected 2 concat functions in merged collection, but found: "
45+
+ concatFunctions.size());
4446
}
4547

46-
String urn1 = ltrimFunctions.get(0).getAnchor().urn();
47-
String urn2 = ltrimFunctions.get(1).getAnchor().urn();
48+
String urn1 = concatFunctions.get(0).getAnchor().urn();
49+
String urn2 = concatFunctions.get(1).getAnchor().urn();
4850
if (urn1.equals(urn2)) {
4951
throw new IllegalStateException(
50-
"Expected different URNs for the two ltrim functions, but both were: " + urn1);
52+
"Expected different URNs for the two concat functions, but both were: " + urn1);
5153
}
5254
} catch (IOException e) {
5355
throw new UncheckedIOException(e);
@@ -92,21 +94,21 @@ void testMergeOrderDeterminesFunctionPrecedence() {
9294
new ScalarFunctionConverter(reverseCollection.scalarFunctions(), typeFactory);
9395

9496
RexBuilder rexBuilder = new RexBuilder(typeFactory);
95-
RexNode arg1 = rexBuilder.makeLiteral("hello");
96-
RexNode arg2 = rexBuilder.makeLiteral("world");
97-
RexCall concatCall = (RexCall) rexBuilder.makeCall(SqlStdOperatorTable.CONCAT, arg1, arg2);
97+
RexCall concatCall =
98+
(RexCall)
99+
rexBuilder.makeCall(
100+
SqlStdOperatorTable.CONCAT,
101+
rexBuilder.makeLiteral("hello"),
102+
rexBuilder.makeLiteral("world"));
98103

99104
// Create a simple topLevelConverter that converts literals to Substrait expressions
100105
java.util.function.Function<RexNode, Expression> topLevelConverter =
101106
rexNode -> {
102-
if (rexNode instanceof org.apache.calcite.rex.RexLiteral) {
103-
org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode;
104-
return Expression.StrLiteral.builder()
105-
.value(lit.getValueAs(String.class))
106-
.nullable(false)
107-
.build();
108-
}
109-
throw new UnsupportedOperationException("Only literals supported in test");
107+
org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode;
108+
return Expression.StrLiteral.builder()
109+
.value(lit.getValueAs(String.class))
110+
.nullable(false)
111+
.build();
110112
};
111113

112114
Optional<Expression> exprA = converterA.convert(concatCall, topLevelConverter);
@@ -125,4 +127,65 @@ void testMergeOrderDeterminesFunctionPrecedence() {
125127
funcB.declaration().getAnchor().urn(),
126128
"converterB should use last concat function (from collection1)");
127129
}
130+
131+
@Test
132+
void testLtrimMergeOrderWithDefaultExtensions() {
133+
// This test verifies precedence between a custom ltrim (from collection2 with
134+
// extension:com.domain:string) and the default extension catalog's ltrim
135+
// (extension:io.substrait:functions_string).
136+
// The FunctionConverter uses a "last-wins" strategy.
137+
138+
// Merge default extensions with collection2 - collection2's ltrim should be last
139+
SimpleExtension.ExtensionCollection defaultWithCustom = extensions.merge(collection2);
140+
141+
// Merge collection2 with default extensions - default ltrim should be last
142+
SimpleExtension.ExtensionCollection customWithDefault = collection2.merge(extensions);
143+
144+
ScalarFunctionConverter converterA =
145+
new ScalarFunctionConverter(defaultWithCustom.scalarFunctions(), typeFactory);
146+
ScalarFunctionConverter converterB =
147+
new ScalarFunctionConverter(customWithDefault.scalarFunctions(), typeFactory);
148+
149+
// Create a TRIM(LEADING ' ' FROM 'test') call which uses TrimFunctionMapper to map to ltrim
150+
RexBuilder rexBuilder = new RexBuilder(typeFactory);
151+
RexCall trimCall =
152+
(RexCall)
153+
rexBuilder.makeCall(
154+
SqlStdOperatorTable.TRIM,
155+
rexBuilder.makeFlag(Flag.LEADING),
156+
rexBuilder.makeLiteral(" "),
157+
rexBuilder.makeLiteral("test"));
158+
159+
java.util.function.Function<RexNode, Expression> topLevelConverter =
160+
rexNode -> {
161+
org.apache.calcite.rex.RexLiteral lit = (org.apache.calcite.rex.RexLiteral) rexNode;
162+
Object value = lit.getValue();
163+
if (value == null) {
164+
return Expression.StrLiteral.builder().value("").nullable(true).build();
165+
}
166+
// Convert any literal value to string
167+
return Expression.StrLiteral.builder().value(value.toString()).nullable(false).build();
168+
};
169+
170+
Optional<Expression> exprA = converterA.convert(trimCall, topLevelConverter);
171+
Optional<Expression> exprB = converterB.convert(trimCall, topLevelConverter);
172+
173+
// Both should successfully convert
174+
assertNotNull(exprA);
175+
assertNotNull(exprB);
176+
177+
Expression.ScalarFunctionInvocation funcA = (Expression.ScalarFunctionInvocation) exprA.get();
178+
// converterA should use collection2's custom ltrim (last)
179+
assertEquals(
180+
"extension:com.domain:string",
181+
funcA.declaration().getAnchor().urn(),
182+
"converterA should use last ltrim (custom from collection2)");
183+
184+
Expression.ScalarFunctionInvocation funcB = (Expression.ScalarFunctionInvocation) exprB.get();
185+
// converterB should use default extensions' ltrim (last)
186+
assertEquals(
187+
"extension:io.substrait:functions_string",
188+
funcB.declaration().getAnchor().urn(),
189+
"converterB should use last ltrim (from default extensions)");
190+
}
128191
}

isthmus/src/test/resources/extensions/functions_duplicate_urn1.yaml

Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,6 @@
33
urn: extension:io.substrait:functions_string
44

55
scalar_functions:
6-
- name: "ltrim"
7-
description: "left trim from standard functions"
8-
impls:
9-
- args:
10-
- name: str
11-
value: string
12-
- name: chars
13-
value: string
14-
return: string
156
- name: "concat"
167
description: "concatenate strings"
178
impls:

0 commit comments

Comments
 (0)