Skip to content

Commit 450c28e

Browse files
committed
not limit argument type for hive simple udf
1 parent fd0b32c commit 450c28e

File tree

2 files changed

+4
-22
lines changed

2 files changed

+4
-22
lines changed

sql/hive/src/main/scala/org/apache/spark/sql/hive/HiveInspectors.scala

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,15 +137,15 @@ private[hive] trait HiveInspectors {
137137

138138
/** Converts native catalyst types to the types expected by Hive */
139139
def wrap(a: Any): AnyRef = a match {
140-
case s: String => new hadoopIo.Text(s) // TODO why should be Text?
140+
case s: String => s: java.lang.String
141141
case i: Int => i: java.lang.Integer
142142
case b: Boolean => b: java.lang.Boolean
143143
case f: Float => f: java.lang.Float
144144
case d: Double => d: java.lang.Double
145145
case l: Long => l: java.lang.Long
146146
case l: Short => l: java.lang.Short
147147
case l: Byte => l: java.lang.Byte
148-
case b: BigDecimal => b.bigDecimal
148+
case b: BigDecimal => new HiveDecimal(b.underlying())
149149
case b: Array[Byte] => b
150150
case t: java.sql.Timestamp => t
151151
case s: Seq[_] => seqAsJavaList(s.map(wrap))

sql/hive/src/main/scala/org/apache/spark/sql/hive/hiveUdfs.scala

Lines changed: 2 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -51,19 +51,7 @@ private[hive] abstract class HiveFunctionRegistry
5151
val functionClassName = functionInfo.getFunctionClass.getName
5252

5353
if (classOf[UDF].isAssignableFrom(functionInfo.getFunctionClass)) {
54-
val function = functionInfo.getFunctionClass.newInstance().asInstanceOf[UDF]
55-
val method = function.getResolver.getEvalMethod(children.map(_.dataType.toTypeInfo))
56-
57-
val expectedDataTypes = method.getParameterTypes.map(javaClassToDataType)
58-
59-
HiveSimpleUdf(
60-
functionClassName,
61-
children.zip(expectedDataTypes).map {
62-
case (e, NullType) => e
63-
case (e, t) if (e.dataType == t) => e
64-
case (e, t) => Cast(e, t)
65-
}
66-
)
54+
HiveSimpleUdf(functionClassName, children)
6755
} else if (classOf[GenericUDF].isAssignableFrom(functionInfo.getFunctionClass)) {
6856
HiveGenericUdf(functionClassName, children)
6957
} else if (
@@ -117,15 +105,9 @@ private[hive] case class HiveSimpleUdf(functionClassName: String, children: Seq[
117105
@transient
118106
lazy val dataType = javaClassToDataType(method.getReturnType)
119107

120-
def catalystToHive(value: Any): Object = value match {
121-
// TODO need more types here? or can we use wrap()
122-
case bd: BigDecimal => new HiveDecimal(bd.underlying())
123-
case d => d.asInstanceOf[Object]
124-
}
125-
126108
// TODO: Finish input output types.
127109
override def eval(input: Row): Any = {
128-
val evaluatedChildren = children.map(c => catalystToHive(c.eval(input)))
110+
val evaluatedChildren = children.map(c => wrap(c.eval(input)))
129111

130112
unwrap(FunctionRegistry.invoke(method, function, conversionHelper
131113
.convertIfNecessary(evaluatedChildren: _*): _*))

0 commit comments

Comments
 (0)