Skip to content

Commit faf74ad

Browse files
[SPARK-50032][SQL] Allow use of fully qualified collation name
### What changes were proposed in this pull request? In this PR collations can now be identified by their fully qualified name, as per the collation project plan. The `Collation` expression has been changed to always return fully qualified name. Currently we only support predefined collations. ### Why are the changes needed? Make collation names behave as per the project spec. ### Does this PR introduce _any_ user-facing change? Yes. Two user-facing changes are made: 1. Collation expression now returns fully qualified name: ```sql select collation('a' collate utf8_lcase) -- returns `SYSTEM.BUILTIN.UTF8_LCASE` ``` 2. Collations can now be identified by their full qualified name: ```sql select contains('a' collate system.builtin.utf8_lcase, 'A') -- returns true ``` ### How was this patch tested? New tests in this PR. ### Was this patch authored or co-authored using generative AI tooling? No. Closes #48546 from stevomitric/stevomitric/fully-qualified-name. Lead-authored-by: Stevo Mitric <[email protected]> Co-authored-by: Wenchen Fan <[email protected]> Signed-off-by: Wenchen Fan <[email protected]>
1 parent 3fab712 commit faf74ad

File tree

20 files changed

+347
-165
lines changed

20 files changed

+347
-165
lines changed

common/unsafe/src/main/java/org/apache/spark/sql/catalyst/util/CollationFactory.java

Lines changed: 46 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -415,18 +415,6 @@ private static Collation fetchCollation(int collationId) {
415415
}
416416
}
417417

418-
/**
419-
* Method for constructing errors thrown on providing invalid collation name.
420-
*/
421-
protected static SparkException collationInvalidNameException(String collationName) {
422-
Map<String, String> params = new HashMap<>();
423-
final int maxSuggestions = 3;
424-
params.put("collationName", collationName);
425-
params.put("proposals", getClosestSuggestionsOnInvalidName(collationName, maxSuggestions));
426-
return new SparkException("COLLATION_INVALID_NAME",
427-
SparkException.constructMessageParams(params), null);
428-
}
429-
430418
private static int collationNameToId(String collationName) throws SparkException {
431419
// Collation names provided by user are treated as case-insensitive.
432420
String collationNameUpper = collationName.toUpperCase();
@@ -1185,6 +1173,52 @@ public static int collationNameToId(String collationName) throws SparkException
11851173
return Collation.CollationSpec.collationNameToId(collationName);
11861174
}
11871175

1176+
/**
1177+
* Returns the resolved fully qualified collation name.
1178+
*/
1179+
public static String resolveFullyQualifiedName(String[] collationName) throws SparkException {
1180+
// If collation name has only one part, then we don't need to do any name resolution.
1181+
if (collationName.length == 1) return collationName[0];
1182+
else {
1183+
// Currently we only support builtin collation names with fixed catalog `SYSTEM` and
1184+
// schema `BUILTIN`.
1185+
if (collationName.length != 3 ||
1186+
!CollationFactory.CATALOG.equalsIgnoreCase(collationName[0]) ||
1187+
!CollationFactory.SCHEMA.equalsIgnoreCase(collationName[1])) {
1188+
// Throw exception with original (before case conversion) collation name.
1189+
throw CollationFactory.collationInvalidNameException(
1190+
collationName.length != 0 ? collationName[collationName.length - 1] : "");
1191+
}
1192+
return collationName[2];
1193+
}
1194+
}
1195+
1196+
/**
1197+
* Method for constructing errors thrown on providing invalid collation name.
1198+
*/
1199+
public static SparkException collationInvalidNameException(String collationName) {
1200+
Map<String, String> params = new HashMap<>();
1201+
final int maxSuggestions = 3;
1202+
params.put("collationName", collationName);
1203+
params.put("proposals", getClosestSuggestionsOnInvalidName(collationName, maxSuggestions));
1204+
return new SparkException("COLLATION_INVALID_NAME",
1205+
SparkException.constructMessageParams(params), null);
1206+
}
1207+
1208+
1209+
1210+
/**
1211+
* Returns the fully qualified collation name for the given collation ID.
1212+
*/
1213+
public static String fullyQualifiedName(int collationId) {
1214+
Collation.CollationSpec.DefinitionOrigin definitionOrigin =
1215+
Collation.CollationSpec.getDefinitionOrigin(collationId);
1216+
// Currently only predefined collations are supported.
1217+
assert definitionOrigin == Collation.CollationSpec.DefinitionOrigin.PREDEFINED;
1218+
return String.format("%s.%s.%s", CATALOG, SCHEMA,
1219+
Collation.CollationSpec.fetchCollation(collationId).collationName);
1220+
}
1221+
11881222
public static boolean isCaseInsensitive(int collationId) {
11891223
return Collation.CollationSpecICU.fromCollationId(collationId).caseSensitivity ==
11901224
Collation.CollationSpecICU.CaseSensitivity.CI;

python/pyspark/sql/functions/builtin.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16756,12 +16756,12 @@ def collation(col: "ColumnOrName") -> Column:
1675616756
Examples
1675716757
--------
1675816758
>>> df = spark.createDataFrame([('name',)], ['dt'])
16759-
>>> df.select(collation('dt').alias('collation')).show()
16760-
+-----------+
16761-
| collation|
16762-
+-----------+
16763-
|UTF8_BINARY|
16764-
+-----------+
16759+
>>> df.select(collation('dt').alias('collation')).show(truncate=False)
16760+
+--------------------------+
16761+
|collation |
16762+
+--------------------------+
16763+
|SYSTEM.BUILTIN.UTF8_BINARY|
16764+
+--------------------------+
1676516765
"""
1676616766
return _invoke_function_over_columns("collation", col)
1676716767

python/pyspark/sql/tests/test_functions.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -456,7 +456,7 @@ def test_string_functions(self):
456456
def test_collation(self):
457457
df = self.spark.createDataFrame([("a",), ("b",)], ["name"])
458458
actual = df.select(F.collation(F.collate("name", "UNICODE"))).distinct().collect()
459-
self.assertEqual([Row("UNICODE")], actual)
459+
self.assertEqual([Row("SYSTEM.BUILTIN.UNICODE")], actual)
460460

461461
def test_try_make_interval(self):
462462
df = self.spark.createDataFrame([(2147483647,)], ["num"])

sql/api/src/main/antlr4/org/apache/spark/sql/catalyst/parser/SqlBaseParser.g4

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1233,7 +1233,7 @@ colPosition
12331233
;
12341234

12351235
collateClause
1236-
: COLLATE collationName=identifier
1236+
: COLLATE collationName=multipartIdentifier
12371237
;
12381238

12391239
type

sql/api/src/main/scala/org/apache/spark/sql/catalyst/parser/DataTypeAstBuilder.scala

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,14 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
5757
}
5858
}
5959

60+
/**
61+
* Create a multi-part identifier.
62+
*/
63+
override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
64+
withOrigin(ctx) {
65+
ctx.parts.asScala.map(_.getText).toSeq
66+
}
67+
6068
/**
6169
* Resolve/create a primitive type.
6270
*/
@@ -78,8 +86,9 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
7886
typeCtx.children.asScala.toSeq match {
7987
case Seq(_) => StringType
8088
case Seq(_, ctx: CollateClauseContext) =>
81-
val collationName = visitCollateClause(ctx)
82-
val collationId = CollationFactory.collationNameToId(collationName)
89+
val collationNameParts = visitCollateClause(ctx).toArray
90+
val collationId = CollationFactory.collationNameToId(
91+
CollationFactory.resolveFullyQualifiedName(collationNameParts))
8392
StringType(collationId)
8493
}
8594
case (CHARACTER | CHAR, length :: Nil) => CharType(length.getText.toInt)
@@ -219,8 +228,8 @@ class DataTypeAstBuilder extends SqlBaseParserBaseVisitor[AnyRef] {
219228
/**
220229
* Returns a collation name.
221230
*/
222-
override def visitCollateClause(ctx: CollateClauseContext): String = withOrigin(ctx) {
223-
ctx.identifier.getText
231+
override def visitCollateClause(ctx: CollateClauseContext): Seq[String] = withOrigin(ctx) {
232+
visitMultipartIdentifier(ctx.collationName)
224233
}
225234

226235
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/analysis/Analyzer.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,6 +299,7 @@ class Analyzer(override val catalogManager: CatalogManager) extends RuleExecutor
299299
ResolveFieldNameAndPosition ::
300300
AddMetadataColumns ::
301301
DeduplicateRelations ::
302+
ResolveCollationName ::
302303
new ResolveReferences(catalogManager) ::
303304
// Please do not insert any other rules in between. See the TODO comments in rule
304305
// ResolveLateralColumnAliasReference for more details.
Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.sql.catalyst.analysis
19+
20+
import org.apache.spark.sql.catalyst.expressions._
21+
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
22+
import org.apache.spark.sql.catalyst.rules.Rule
23+
import org.apache.spark.sql.catalyst.trees.TreePattern.UNRESOLVED_COLLATION
24+
import org.apache.spark.sql.catalyst.util.CollationFactory
25+
26+
/**
27+
* Resolves fully qualified collation name and replaces [[UnresolvedCollation]] with
28+
* [[ResolvedCollation]].
29+
*/
30+
object ResolveCollationName extends Rule[LogicalPlan] {
31+
def apply(plan: LogicalPlan): LogicalPlan =
32+
plan.resolveExpressionsWithPruning(_.containsPattern(UNRESOLVED_COLLATION), ruleId) {
33+
case UnresolvedCollation(collationName) =>
34+
ResolvedCollation(CollationFactory.resolveFullyQualifiedName(collationName.toArray))
35+
}
36+
}

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/expressions/collationExpressions.scala

Lines changed: 61 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -17,10 +17,12 @@
1717

1818
package org.apache.spark.sql.catalyst.expressions
1919

20+
import org.apache.spark.SparkException
2021
import org.apache.spark.sql.catalyst.InternalRow
21-
import org.apache.spark.sql.catalyst.analysis.ExpressionBuilder
22+
import org.apache.spark.sql.catalyst.analysis.{ExpressionBuilder, UnresolvedException}
2223
import org.apache.spark.sql.catalyst.expressions.codegen._
23-
import org.apache.spark.sql.catalyst.util.CollationFactory
24+
import org.apache.spark.sql.catalyst.trees.TreePattern.{TreePattern, UNRESOLVED_COLLATION}
25+
import org.apache.spark.sql.catalyst.util.{AttributeNameParser, CollationFactory}
2426
import org.apache.spark.sql.errors.QueryCompilationErrors
2527
import org.apache.spark.sql.internal.SQLConf
2628
import org.apache.spark.sql.internal.types.StringTypeWithCollation
@@ -37,7 +39,7 @@ import org.apache.spark.sql.types._
3739
examples = """
3840
Examples:
3941
> SELECT COLLATION('Spark SQL' _FUNC_ UTF8_LCASE);
40-
UTF8_LCASE
42+
SYSTEM.BUILTIN.UTF8_LCASE
4143
""",
4244
since = "4.0.0",
4345
group = "string_funcs")
@@ -56,7 +58,8 @@ object CollateExpressionBuilder extends ExpressionBuilder {
5658
evalCollation.toString.toUpperCase().contains("TRIM")) {
5759
throw QueryCompilationErrors.trimCollationNotEnabledError()
5860
}
59-
Collate(e, evalCollation.toString)
61+
Collate(e, UnresolvedCollation(
62+
AttributeNameParser.parseAttributeName(evalCollation.toString)))
6063
}
6164
case (_: StringType, false) => throw QueryCompilationErrors.nonFoldableArgumentError(
6265
funcName, "collationName", StringType)
@@ -73,24 +76,63 @@ object CollateExpressionBuilder extends ExpressionBuilder {
7376
* This function is pass-through, it will not modify the input data.
7477
* Only type metadata will be updated.
7578
*/
76-
case class Collate(child: Expression, collationName: String)
77-
extends UnaryExpression with ExpectsInputTypes {
78-
private val collationId = CollationFactory.collationNameToId(collationName)
79-
override def dataType: DataType = StringType(collationId)
79+
case class Collate(child: Expression, collation: Expression)
80+
extends BinaryExpression with ExpectsInputTypes {
81+
override def left: Expression = child
82+
override def right: Expression = collation
83+
override def dataType: DataType = collation.dataType
8084
override def inputTypes: Seq[AbstractDataType] =
81-
Seq(StringTypeWithCollation(supportsTrimCollation = true))
82-
83-
override protected def withNewChildInternal(
84-
newChild: Expression): Expression = copy(newChild)
85+
Seq(StringTypeWithCollation(supportsTrimCollation = true), AnyDataType)
8586

8687
override def eval(row: InternalRow): Any = child.eval(row)
8788

88-
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode =
89-
defineCodeGen(ctx, ev, (in) => in)
89+
/** Just a simple passthrough for code generation. */
90+
override def genCode(ctx: CodegenContext): ExprCode = child.genCode(ctx)
91+
override protected def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
92+
throw SparkException.internalError("Collate.doGenCode should not be called.")
93+
}
94+
95+
override def sql: String = s"$prettyName(${child.sql}, $collation)"
96+
97+
override def toString: String =
98+
s"$prettyName($child, $collation)"
99+
100+
override protected def withNewChildrenInternal(
101+
newLeft: Expression, newRight: Expression): Expression =
102+
copy(child = newLeft, collation = newRight)
103+
104+
override def foldable: Boolean = child.foldable
105+
}
106+
107+
/**
108+
* An expression that marks an unresolved collation name.
109+
*
110+
* This class is used to represent a collation name that has not yet been resolved from a fully
111+
* qualified collation name. It is used during the analysis phase, where the collation name is
112+
* specified but not yet validated or resolved.
113+
*/
114+
case class UnresolvedCollation(collationName: Seq[String])
115+
extends LeafExpression with Unevaluable {
116+
override def dataType: DataType = throw new UnresolvedException("dataType")
117+
118+
override def nullable: Boolean = false
119+
120+
override lazy val resolved: Boolean = false
121+
122+
final override val nodePatterns: Seq[TreePattern] = Seq(UNRESOLVED_COLLATION)
123+
}
124+
125+
/**
126+
* An expression that represents a resolved collation name.
127+
*/
128+
case class ResolvedCollation(collationName: String) extends LeafExpression with Unevaluable {
129+
override def nullable: Boolean = false
130+
131+
override def dataType: DataType = StringType(CollationFactory.collationNameToId(collationName))
90132

91-
override def sql: String = s"$prettyName(${child.sql}, $collationName)"
133+
override def toString: String = collationName
92134

93-
override def toString: String = s"$prettyName($child, $collationName)"
135+
override def sql: String = collationName
94136
}
95137

96138
// scalastyle:off line.contains.tab
@@ -103,7 +145,7 @@ case class Collate(child: Expression, collationName: String)
103145
examples = """
104146
Examples:
105147
> SELECT _FUNC_('Spark SQL');
106-
UTF8_BINARY
148+
SYSTEM.BUILTIN.UTF8_BINARY
107149
""",
108150
since = "4.0.0",
109151
group = "string_funcs")
@@ -113,8 +155,8 @@ case class Collation(child: Expression)
113155
override protected def withNewChildInternal(newChild: Expression): Collation = copy(newChild)
114156
override lazy val replacement: Expression = {
115157
val collationId = child.dataType.asInstanceOf[StringType].collationId
116-
val collationName = CollationFactory.fetchCollation(collationId).collationName
117-
Literal.create(collationName, SQLConf.get.defaultStringType)
158+
val fullyQualifiedCollationName = CollationFactory.fullyQualifiedName(collationId)
159+
Literal.create(fullyQualifiedCollationName, SQLConf.get.defaultStringType)
118160
}
119161
override def inputTypes: Seq[AbstractDataType] =
120162
Seq(StringTypeWithCollation(supportsTrimCollation = true))

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/parser/AstBuilder.scala

Lines changed: 6 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2286,14 +2286,6 @@ class AstBuilder extends DataTypeAstBuilder
22862286
FunctionIdentifier(ctx.function.getText, Option(ctx.db).map(_.getText))
22872287
}
22882288

2289-
/**
2290-
* Create a multi-part identifier.
2291-
*/
2292-
override def visitMultipartIdentifier(ctx: MultipartIdentifierContext): Seq[String] =
2293-
withOrigin(ctx) {
2294-
ctx.parts.asScala.map(_.getText).toSeq
2295-
}
2296-
22972289
/* ********************************************************************************************
22982290
* Expression parsing
22992291
* ******************************************************************************************** */
@@ -2706,15 +2698,16 @@ class AstBuilder extends DataTypeAstBuilder
27062698
*/
27072699
override def visitCollate(ctx: CollateContext): Expression = withOrigin(ctx) {
27082700
val collationName = visitCollateClause(ctx.collateClause())
2709-
Collate(expression(ctx.primaryExpression), collationName)
2701+
2702+
Collate(expression(ctx.primaryExpression), UnresolvedCollation(collationName))
27102703
}
27112704

2712-
override def visitCollateClause(ctx: CollateClauseContext): String = withOrigin(ctx) {
2713-
val collationName = ctx.collationName.getText
2714-
if (!SQLConf.get.trimCollationEnabled && collationName.toUpperCase().contains("TRIM")) {
2705+
override def visitCollateClause(ctx: CollateClauseContext): Seq[String] = withOrigin(ctx) {
2706+
val collationName = visitMultipartIdentifier(ctx.collationName)
2707+
if (!SQLConf.get.trimCollationEnabled && collationName.last.toUpperCase().contains("TRIM")) {
27152708
throw QueryCompilationErrors.trimCollationNotEnabledError()
27162709
}
2717-
ctx.identifier.getText
2710+
collationName
27182711
}
27192712

27202713
/**

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/rules/RuleIdCollection.scala

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ object RuleIdCollection {
5151
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAggregateFunctions" ::
5252
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveAliases" ::
5353
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveBinaryArithmetic" ::
54+
"org.apache.spark.sql.catalyst.analysis.ResolveCollationName" ::
5455
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveDeserializer" ::
5556
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveEncodersInUDF" ::
5657
"org.apache.spark.sql.catalyst.analysis.Analyzer$ResolveFunctions" ::

0 commit comments

Comments
 (0)