1616import org .apache .calcite .rex .RexCall ;
1717import org .apache .calcite .rex .RexNode ;
1818import org .apache .calcite .sql .fun .SqlStdOperatorTable ;
19+ import org .apache .calcite .sql .fun .SqlTrimFunction .Flag ;
1920import org .junit .jupiter .api .Test ;
2021
2122/** Tests to reproduce #562 */
@@ -33,21 +34,22 @@ public class DuplicateFunctionUrnTest extends PlanTestBase {
3334 collection2 = SimpleExtension .load ("urn2://functions" , extensions2 );
3435 collection = collection1 .merge (collection2 );
3536
36- // Verify that the merged collection contains duplicate functions with different URNs
37+ // Verify that the merged collection contains duplicate concat functions with different URNs
3738 // This is a precondition for the tests - if this fails, the tests don't make sense
38- List <SimpleExtension .ScalarFunctionVariant > ltrimFunctions =
39- collection .scalarFunctions ().stream ().filter (f -> f .name ().equals ("ltrim " )).toList ();
39+ List <SimpleExtension .ScalarFunctionVariant > concatFunctions =
40+ collection .scalarFunctions ().stream ().filter (f -> f .name ().equals ("concat " )).toList ();
4041
41- if (ltrimFunctions .size () != 2 ) {
42+ if (concatFunctions .size () != 2 ) {
4243 throw new IllegalStateException (
43- "Expected 2 ltrim functions in merged collection, but found: " + ltrimFunctions .size ());
44+ "Expected 2 concat functions in merged collection, but found: "
45+ + concatFunctions .size ());
4446 }
4547
46- String urn1 = ltrimFunctions .get (0 ).getAnchor ().urn ();
47- String urn2 = ltrimFunctions .get (1 ).getAnchor ().urn ();
48+ String urn1 = concatFunctions .get (0 ).getAnchor ().urn ();
49+ String urn2 = concatFunctions .get (1 ).getAnchor ().urn ();
4850 if (urn1 .equals (urn2 )) {
4951 throw new IllegalStateException (
50- "Expected different URNs for the two ltrim functions, but both were: " + urn1 );
52+ "Expected different URNs for the two concat functions, but both were: " + urn1 );
5153 }
5254 } catch (IOException e ) {
5355 throw new UncheckedIOException (e );
@@ -92,21 +94,21 @@ void testMergeOrderDeterminesFunctionPrecedence() {
9294 new ScalarFunctionConverter (reverseCollection .scalarFunctions (), typeFactory );
9395
9496 RexBuilder rexBuilder = new RexBuilder (typeFactory );
95- RexNode arg1 = rexBuilder .makeLiteral ("hello" );
96- RexNode arg2 = rexBuilder .makeLiteral ("world" );
97- RexCall concatCall = (RexCall ) rexBuilder .makeCall (SqlStdOperatorTable .CONCAT , arg1 , arg2 );
97+ RexCall concatCall =
98+ (RexCall )
99+ rexBuilder .makeCall (
100+ SqlStdOperatorTable .CONCAT ,
101+ rexBuilder .makeLiteral ("hello" ),
102+ rexBuilder .makeLiteral ("world" ));
98103
99104 // Create a simple topLevelConverter that converts literals to Substrait expressions
100105 java .util .function .Function <RexNode , Expression > topLevelConverter =
101106 rexNode -> {
102- if (rexNode instanceof org .apache .calcite .rex .RexLiteral ) {
103- org .apache .calcite .rex .RexLiteral lit = (org .apache .calcite .rex .RexLiteral ) rexNode ;
104- return Expression .StrLiteral .builder ()
105- .value (lit .getValueAs (String .class ))
106- .nullable (false )
107- .build ();
108- }
109- throw new UnsupportedOperationException ("Only literals supported in test" );
107+ org .apache .calcite .rex .RexLiteral lit = (org .apache .calcite .rex .RexLiteral ) rexNode ;
108+ return Expression .StrLiteral .builder ()
109+ .value (lit .getValueAs (String .class ))
110+ .nullable (false )
111+ .build ();
110112 };
111113
112114 Optional <Expression > exprA = converterA .convert (concatCall , topLevelConverter );
@@ -125,4 +127,65 @@ void testMergeOrderDeterminesFunctionPrecedence() {
125127 funcB .declaration ().getAnchor ().urn (),
126128 "converterB should use last concat function (from collection1)" );
127129 }
130+
131+ @ Test
132+ void testLtrimMergeOrderWithDefaultExtensions () {
133+ // This test verifies precedence between a custom ltrim (from collection2 with
134+ // extension:com.domain:string) and the default extension catalog's ltrim
135+ // (extension:io.substrait:functions_string).
136+ // The FunctionConverter uses a "last-wins" strategy.
137+
138+ // Merge default extensions with collection2 - collection2's ltrim should be last
139+ SimpleExtension .ExtensionCollection defaultWithCustom = extensions .merge (collection2 );
140+
141+ // Merge collection2 with default extensions - default ltrim should be last
142+ SimpleExtension .ExtensionCollection customWithDefault = collection2 .merge (extensions );
143+
144+ ScalarFunctionConverter converterA =
145+ new ScalarFunctionConverter (defaultWithCustom .scalarFunctions (), typeFactory );
146+ ScalarFunctionConverter converterB =
147+ new ScalarFunctionConverter (customWithDefault .scalarFunctions (), typeFactory );
148+
149+ // Create a TRIM(LEADING ' ' FROM 'test') call which uses TrimFunctionMapper to map to ltrim
150+ RexBuilder rexBuilder = new RexBuilder (typeFactory );
151+ RexCall trimCall =
152+ (RexCall )
153+ rexBuilder .makeCall (
154+ SqlStdOperatorTable .TRIM ,
155+ rexBuilder .makeFlag (Flag .LEADING ),
156+ rexBuilder .makeLiteral (" " ),
157+ rexBuilder .makeLiteral ("test" ));
158+
159+ java .util .function .Function <RexNode , Expression > topLevelConverter =
160+ rexNode -> {
161+ org .apache .calcite .rex .RexLiteral lit = (org .apache .calcite .rex .RexLiteral ) rexNode ;
162+ Object value = lit .getValue ();
163+ if (value == null ) {
164+ return Expression .StrLiteral .builder ().value ("" ).nullable (true ).build ();
165+ }
166+ // Convert any literal value to string
167+ return Expression .StrLiteral .builder ().value (value .toString ()).nullable (false ).build ();
168+ };
169+
170+ Optional <Expression > exprA = converterA .convert (trimCall , topLevelConverter );
171+ Optional <Expression > exprB = converterB .convert (trimCall , topLevelConverter );
172+
173+ // Both should successfully convert
174+ assertNotNull (exprA );
175+ assertNotNull (exprB );
176+
177+ Expression .ScalarFunctionInvocation funcA = (Expression .ScalarFunctionInvocation ) exprA .get ();
178+ // converterA should use collection2's custom ltrim (last)
179+ assertEquals (
180+ "extension:com.domain:string" ,
181+ funcA .declaration ().getAnchor ().urn (),
182+ "converterA should use last ltrim (custom from collection2)" );
183+
184+ Expression .ScalarFunctionInvocation funcB = (Expression .ScalarFunctionInvocation ) exprB .get ();
185+ // converterB should use default extensions' ltrim (last)
186+ assertEquals (
187+ "extension:io.substrait:functions_string" ,
188+ funcB .declaration ().getAnchor ().urn (),
189+ "converterB should use last ltrim (from default extensions)" );
190+ }
128191}
0 commit comments