Skip to content

Commit 8757094

Browse files
committed
Fixes #16
Detailed research revealed that we were incorrectly inferring field names when they don't conform to the JavaBean specification. It turns out that we were inferring the field name from the getter, which is not always correct. So in this commit we're adding a new wrapper, KStructField, which carries not only the field name but also the getter name, giving us the ability to match fields not by an inferred field name but by getter name, which is much safer — Kotlin will never allow a capitalized and a non-capitalized property to co-exist in the same class (like `Country` and `country`), so using the getter is the more correct way to handle this situation. Signed-off-by: Pasha Finkelshteyn <[email protected]>
1 parent 91afeff commit 8757094

File tree

3 files changed

+49
-18
lines changed

3 files changed

+49
-18
lines changed

core/src/main/scala/org/apache/spark/sql/KotlinReflection.scala

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -653,10 +653,11 @@ object KotlinReflection extends KotlinReflection {
653653
val cls = dataType.cls
654654
val properties = getJavaBeanReadableProperties(cls)
655655
val fields = properties.map { prop =>
656-
val fieldName = prop.getName
657-
val maybeField = dataType.dt.fields.find(it => it.name == fieldName)
656+
657+
val maybeField = dataType.dt.fields.map(_.asInstanceOf[KStructField]).find(it => it.getterName == prop.getReadMethod.getName)
658658
if (maybeField.isEmpty)
659-
throw new IllegalArgumentException(s"Field $fieldName is not found among available fields, which are: ${dataType.dt.fields.map(_.name).mkString(", ")}")
659+
throw new IllegalArgumentException(s"Field ${prop.getName} is not found among available fields, which are: ${dataType.dt.fields.map(_.name).mkString(", ")}")
660+
val fieldName = maybeField.get.name
660661
val propClass = maybeField.map(it => it.dataType.asInstanceOf[DataTypeWithClass].cls).get
661662
val propDt = maybeField.map(it => it.dataType.asInstanceOf[DataTypeWithClass]).get
662663

core/src/main/scala/org/apache/spark/sql/KotlinWrappers.scala

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -178,6 +178,36 @@ case class KSimpleTypeWrapper(dt: DataType, cls: Class[_], nullable: Boolean) ex
178178
override private[spark] def asNullable = dt.asNullable
179179
}
180180

181+
/**
 * A [[StructField]] wrapper that additionally records the name of the JVM
 * getter backing the field, so that fields can be matched by their actual
 * getter rather than by a bean-style property name inferred from it (which
 * breaks for properties that do not follow JavaBean capitalization rules).
 *
 * Every [[StructField]] operation is forwarded unchanged to [[delegate]].
 *
 * @param getterName name of the read accessor (e.g. `getCountry`/`isActive`)
 * @param delegate   the underlying [[StructField]] carrying name/type/metadata
 */
class KStructField(val getterName: String, val delegate: StructField) extends StructField {

  // Core field attributes, mirrored from the wrapped instance.
  override val name: String = delegate.name
  override val dataType: DataType = delegate.dataType
  override val nullable: Boolean = delegate.nullable
  override val metadata: Metadata = delegate.metadata

  // Public StructField API, forwarded verbatim.
  override def withComment(comment: String): StructField = delegate.withComment(comment)
  override def getComment(): Option[String] = delegate.getComment()
  override def toDDL: String = delegate.toDDL
  override def toString(): String = delegate.toString()

  // Spark-internal rendering/serialization hooks.
  override private[sql] def jsonValue = delegate.jsonValue
  override private[sql] def buildFormattedString(prefix: String, stringConcat: StringUtils.StringConcat, maxDepth: Int): Unit =
    delegate.buildFormattedString(prefix, stringConcat, maxDepth)

  // Product members forward to the delegate so pattern matching and
  // case-class-style equality behave exactly as for the wrapped StructField.
  override def productArity: Int = delegate.productArity
  override def productElement(n: Int): Any = delegate.productElement(n)
  override def productIterator: Iterator[Any] = delegate.productIterator
  override def productPrefix: String = delegate.productPrefix
  override def canEqual(that: Any): Boolean = delegate.canEqual(that)
}
210+
181211
object helpme {
182212

183213
def listToSeq(i: java.util.List[_]): Seq[_] = Seq(i.toArray: _*)

kotlin-spark-api/src/main/kotlin/org/jetbrains/spark/api/ApiV1.kt

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,14 @@
2121

2222
package org.jetbrains.spark.api
2323

24-
import org.apache.spark.SparkContext
2524
import org.apache.spark.api.java.function.*
2625
import org.apache.spark.sql.*
2726
import org.apache.spark.sql.Encoders.*
2827
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
2928
import org.apache.spark.sql.types.*
3029
import org.jetbrains.spark.extensions.KSparkExtensions
3130
import scala.reflect.ClassTag
31+
import java.beans.PropertyDescriptor
3232
import java.math.BigDecimal
3333
import java.sql.Date
3434
import java.sql.Timestamp
@@ -281,20 +281,20 @@ fun schema(type: KType, map: Map<String, KType> = mapOf()): DataType {
281281
mapValueParam.isMarkedNullable
282282
)
283283
}
284-
else -> KDataTypeWrapper(
285-
StructType(
286-
klass
287-
.declaredMemberProperties
288-
.filter { it.findAnnotation<Transient>() == null }
289-
.map {
290-
val projectedType = types[it.returnType.toString()] ?: it.returnType
291-
StructField(it.name, schema(projectedType, types), projectedType.isMarkedNullable, Metadata.empty())
292-
}
293-
.toTypedArray()
294-
),
295-
klass.java,
296-
true
297-
)
284+
else -> {
285+
val structType = StructType(
286+
klass
287+
.declaredMemberProperties
288+
.filter { it.findAnnotation<Transient>() == null }
289+
.map {
290+
val projectedType = types[it.returnType.toString()] ?: it.returnType
291+
val propertyDescriptor = PropertyDescriptor(it.name, klass.java, "is" + it.name.capitalize(), null)
292+
KStructField(propertyDescriptor.readMethod.name, StructField(it.name, schema(projectedType, types), projectedType.isMarkedNullable, Metadata.empty()))
293+
}
294+
.toTypedArray()
295+
)
296+
KDataTypeWrapper(structType, klass.java, true)
297+
}
298298
}
299299
}
300300

0 commit comments

Comments
 (0)