diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/toDataFrame.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/toDataFrame.kt index 4dada877d7..60fb38b76a 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/toDataFrame.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/impl/api/toDataFrame.kt @@ -25,9 +25,12 @@ import org.jetbrains.kotlin.fir.symbols.SymbolInternals import org.jetbrains.kotlin.fir.symbols.impl.ConeClassLikeLookupTagImpl import org.jetbrains.kotlin.fir.symbols.impl.FirPropertySymbol import org.jetbrains.kotlin.fir.types.ConeClassLikeType +import org.jetbrains.kotlin.fir.types.ConeFlexibleType import org.jetbrains.kotlin.fir.types.ConeKotlinType +import org.jetbrains.kotlin.fir.types.ConeNullability import org.jetbrains.kotlin.fir.types.ConeStarProjection import org.jetbrains.kotlin.fir.types.ConeTypeParameterType +import org.jetbrains.kotlin.fir.types.ConeTypeProjection import org.jetbrains.kotlin.fir.types.canBeNull import org.jetbrains.kotlin.fir.types.classId import org.jetbrains.kotlin.fir.types.coneType @@ -41,8 +44,10 @@ import org.jetbrains.kotlin.fir.types.resolvedType import org.jetbrains.kotlin.fir.types.toRegularClassSymbol import org.jetbrains.kotlin.fir.types.toSymbol import org.jetbrains.kotlin.fir.types.type +import org.jetbrains.kotlin.fir.types.typeContext import org.jetbrains.kotlin.fir.types.upperBoundIfFlexible import org.jetbrains.kotlin.fir.types.withArguments +import org.jetbrains.kotlin.fir.types.withNullability import org.jetbrains.kotlin.name.ClassId import org.jetbrains.kotlin.name.FqName import org.jetbrains.kotlin.name.Name @@ -50,6 +55,7 @@ import org.jetbrains.kotlin.name.StandardClassIds import org.jetbrains.kotlin.name.StandardClassIds.List import org.jetbrains.kotlinx.dataframe.codeGen.* import org.jetbrains.kotlinx.dataframe.plugin.extensions.KotlinTypeFacade +import org.jetbrains.kotlinx.dataframe.plugin.extensions.wrap import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractInterpreter import org.jetbrains.kotlinx.dataframe.plugin.impl.AbstractSchemaModificationInterpreter import org.jetbrains.kotlinx.dataframe.plugin.impl.Arguments @@ -71,9 +77,11 @@ import java.util.* class ToDataFrameDsl : AbstractSchemaModificationInterpreter() { val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id) val Arguments.body by dsl() + val Arguments.typeArg0: ConeTypeProjection? by arg(lens = Interpreter.Id) + override fun Arguments.interpret(): PluginDataFrameSchema { val dsl = CreateDataFrameDslImplApproximation() - body(dsl, mapOf("explicitReceiver" to Interpreter.Success(receiver))) + body(dsl, mapOf("typeArg0" to Interpreter.Success(typeArg0))) return PluginDataFrameSchema(dsl.columns) } } @@ -81,17 +89,19 @@ class ToDataFrameDsl : AbstractSchemaModificationInterpreter() { class ToDataFrame : AbstractSchemaModificationInterpreter() { val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id) val Arguments.maxDepth: Number by arg(defaultValue = Present(DEFAULT_MAX_DEPTH)) + val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id) override fun Arguments.interpret(): PluginDataFrameSchema { - return toDataFrame(maxDepth.toInt(), receiver, TraverseConfiguration()) + return toDataFrame(maxDepth.toInt(), typeArg0, TraverseConfiguration()) } } class ToDataFrameDefault : AbstractSchemaModificationInterpreter() { val Arguments.receiver: FirExpression? by arg(lens = Interpreter.Id) + val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id) override fun Arguments.interpret(): PluginDataFrameSchema { - return toDataFrame(DEFAULT_MAX_DEPTH, receiver, TraverseConfiguration()) + return toDataFrame(DEFAULT_MAX_DEPTH, typeArg0, TraverseConfiguration()) } } @@ -109,14 +119,14 @@ private const val DEFAULT_MAX_DEPTH = 0 class Properties0 : AbstractInterpreter() { val Arguments.dsl: CreateDataFrameDslImplApproximation by arg() - val Arguments.explicitReceiver: FirExpression? by arg() val Arguments.maxDepth: Int by arg() val Arguments.body by dsl() + val Arguments.typeArg0: ConeTypeProjection by arg(lens = Interpreter.Id) override fun Arguments.interpret() { dsl.configuration.maxDepth = maxDepth body(dsl.configuration.traverseConfiguration, emptyMap()) - val schema = toDataFrame(dsl.configuration.maxDepth, explicitReceiver, dsl.configuration.traverseConfiguration) + val schema = toDataFrame(dsl.configuration.maxDepth, typeArg0, dsl.configuration.traverseConfiguration) dsl.columns.addAll(schema.columns()) } } @@ -172,8 +182,8 @@ class Exclude1 : AbstractInterpreter() { @OptIn(SymbolInternals::class) internal fun KotlinTypeFacade.toDataFrame( maxDepth: Int, - explicitReceiver: FirExpression?, - traverseConfiguration: TraverseConfiguration + arg: ConeTypeProjection, + traverseConfiguration: TraverseConfiguration, ): PluginDataFrameSchema { fun ConeKotlinType.isValueType() = this.isArrayTypeOrNullableArrayType || @@ -197,7 +207,7 @@ internal fun KotlinTypeFacade.toDataFrame( val preserveClasses = traverseConfiguration.preserveClasses.mapNotNullTo(mutableSetOf()) { it.classId } val preserveProperties = traverseConfiguration.preserveProperties.mapNotNullTo(mutableSetOf()) { it.calleeReference.toResolvedPropertySymbol() } - fun convert(classLike: ConeKotlinType, depth: Int): List { + fun convert(classLike: ConeKotlinType, depth: Int, makeNullable: Boolean): List { val symbol = classLike.toRegularClassSymbol(session) ?: return emptyList() val scope = symbol.unsubstitutedScope(session, ScopeSession(), false, FirResolvePhase.STATUS) val declarations = if (symbol.fir is FirJavaClass) { @@ -260,7 +270,7 @@ internal fun KotlinTypeFacade.toDataFrame( val keepSubtree = depth >= maxDepth && !fieldKind.shouldBeConvertedToColumnGroup && !fieldKind.shouldBeConvertedToFrameColumn if (keepSubtree || returnType.isValueType() || returnType.classId in preserveClasses || it in preserveProperties) { - SimpleDataColumn(name, TypeApproximation(returnType)) + SimpleDataColumn(name, TypeApproximation(returnType.withNullability(ConeNullability.create(makeNullable), session.typeContext))) } else if ( returnType.isSubtypeOf(StandardClassIds.Iterable.constructClassLikeType(arrayOf(ConeStarProjection)), session) || returnType.isSubtypeOf(StandardClassIds.Iterable.constructClassLikeType(arrayOf(ConeStarProjection), isNullable = true), session) @@ -271,30 +281,28 @@ internal fun KotlinTypeFacade.toDataFrame( else -> session.builtinTypes.nullableAnyType.type } if (type.isValueType()) { - SimpleDataColumn(name, - TypeApproximation( - List.constructClassLikeType( - arrayOf(type), - returnType.isNullable - ) - ) - ) + val columnType = List.constructClassLikeType(arrayOf(type), returnType.isNullable) + .withNullability(ConeNullability.create(makeNullable), session.typeContext) + .wrap() + SimpleDataColumn(name, columnType) } else { - SimpleFrameColumn(name, convert(type, depth + 1)) + SimpleFrameColumn(name, convert(type, depth + 1, makeNullable = false)) } } else { - SimpleColumnGroup(name, convert(returnType, depth + 1)) + SimpleColumnGroup(name, convert(returnType, depth + 1, returnType.isNullable || makeNullable)) } } } - val receiver = explicitReceiver ?: return PluginDataFrameSchema.EMPTY - val arg = receiver.resolvedType.typeArguments.firstOrNull() ?: return PluginDataFrameSchema.EMPTY return when { arg.isStarProjection -> PluginDataFrameSchema.EMPTY else -> { - val classLike = arg.type as? ConeClassLikeType ?: return PluginDataFrameSchema.EMPTY - val columns = convert(classLike, 0) + val classLike = when (val type = arg.type) { + is ConeClassLikeType -> type + is ConeFlexibleType -> type.upperBound + else -> null + } ?: return PluginDataFrameSchema.EMPTY + val columns = convert(classLike, 0, makeNullable = classLike.isNullable) PluginDataFrameSchema(columns) } } diff --git a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt index 30de23defb..fe135e62a5 100644 --- a/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt +++ b/plugins/kotlin-dataframe/src/org/jetbrains/kotlinx/dataframe/plugin/interpret.kt @@ -90,8 +90,8 @@ fun KotlinTypeFacade.interpret( val refinedArguments: RefinedArguments = functionCall.collectArgumentExpressions() val defaultArguments = processor.expectedArguments.filter { it.defaultValue is Present }.map { it.name }.toSet() - val actualArgsMap = refinedArguments.associateBy { it.name.identifier }.toSortedMap() - val conflictingKeys = additionalArguments.keys intersect actualArgsMap.keys + val actualValueArguments = refinedArguments.associateBy { it.name.identifier }.toSortedMap() + val conflictingKeys = additionalArguments.keys intersect actualValueArguments.keys if (conflictingKeys.isNotEmpty()) { if (isTest) { interpretationFrameworkError("Conflicting keys: $conflictingKeys") @@ -99,20 +99,34 @@ fun KotlinTypeFacade.interpret( return null } val expectedArgsMap = processor.expectedArguments - .filterNot { it.name.startsWith("typeArg") } .associateBy { it.name }.toSortedMap().minus(additionalArguments.keys) - val unexpectedArguments = expectedArgsMap.keys - defaultArguments != actualArgsMap.keys - defaultArguments + val typeArguments = buildMap { + functionCall.typeArguments.forEachIndexed { index, firTypeProjection -> + val key = "typeArg$index" + val lens = expectedArgsMap[key]?.lens ?: return@forEachIndexed + val value: Any = if (lens == Interpreter.Id) { + firTypeProjection.toConeTypeProjection() + } else { + val type = firTypeProjection.toConeTypeProjection().type ?: session.builtinTypes.nullableAnyType.type + if (type is ConeIntersectionType) return@forEachIndexed + Marker(type) + } + put(key, Interpreter.Success(value)) + } + } + + val unexpectedArguments = (expectedArgsMap.keys - defaultArguments) != (actualValueArguments.keys + typeArguments.keys - defaultArguments) if (unexpectedArguments) { if (isTest) { val message = buildString { appendLine("ERROR: Different set of arguments") appendLine("Implementation class: $processor") - appendLine("Not found in actual: ${expectedArgsMap.keys - actualArgsMap.keys}") - val diff = actualArgsMap.keys - expectedArgsMap.keys + appendLine("Not found in actual: ${expectedArgsMap.keys - actualValueArguments.keys}") + val diff = actualValueArguments.keys - expectedArgsMap.keys appendLine("Passed, but not expected: ${diff}") appendLine("add arguments to an interpeter:") - appendLine(diff.map { actualArgsMap[it] }) + appendLine(diff.map { actualValueArguments[it] }) } interpretationFrameworkError(message) } @@ -121,6 +135,7 @@ fun KotlinTypeFacade.interpret( val arguments = mutableMapOf>() arguments += additionalArguments + arguments += typeArguments val interpretationResults = refinedArguments.refinedArguments.mapNotNull { val name = it.name.identifier val expectedArgument = expectedArgsMap[name] ?: error("$processor $name") @@ -269,17 +284,6 @@ fun KotlinTypeFacade.interpret( value?.let { value1 -> it.name.identifier to value1 } } - functionCall.typeArguments.forEachIndexed { index, firTypeProjection -> - val type = firTypeProjection.toConeTypeProjection().type ?: session.builtinTypes.nullableAnyType.type - if (type is ConeIntersectionType) return@forEachIndexed -// val approximation = TypeApproximationImpl( -// type.classId!!.asFqNameString(), -// type.isMarkedNullable -// ) - val approximation = Marker(type) - arguments["typeArg$index"] = Interpreter.Success(approximation) - } - return if (interpretationResults.size == refinedArguments.refinedArguments.size) { arguments.putAll(interpretationResults) when (val res = processor.interpret(arguments, this)) { diff --git a/plugins/kotlin-dataframe/testData/box/toDataFrame_customIterable.kt b/plugins/kotlin-dataframe/testData/box/toDataFrame_customIterable.kt new file mode 100644 index 0000000000..397b514988 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/toDataFrame_customIterable.kt @@ -0,0 +1,29 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +@DataSchema +data class D( + val s: String +) + +class Subtree( + val p: Int, + val l: List, + val ld: List, +) + +class Root(val a: Subtree) + +class MyList(val l: List): List by l + +fun box(): String { + val l = listOf( + Root(Subtree(123, listOf(1), listOf(D("ff")))), + null + ) + val df = MyList(l).toDataFrame(maxDepth = 2) + df.compareSchemas(strict = true) + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/toDataFrame_nullableList.kt b/plugins/kotlin-dataframe/testData/box/toDataFrame_nullableList.kt new file mode 100644 index 0000000000..fe197b26b4 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/toDataFrame_nullableList.kt @@ -0,0 +1,16 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +@DataSchema +data class D( + val s: String +) + +fun box(): String { + val df1 = listOf(D("bb"), null).toDataFrame() + df1.schema().print() + df1.compileTimeSchema().print() + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/toDataFrame_nullableListSubtree.kt b/plugins/kotlin-dataframe/testData/box/toDataFrame_nullableListSubtree.kt new file mode 100644 index 0000000000..2666765cf5 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/toDataFrame_nullableListSubtree.kt @@ -0,0 +1,27 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +@DataSchema +data class D( + val s: String +) + +class Subtree( + val p: Int, + val l: List, + val ld: List, +) + +class Root(val a: Subtree) + +fun box(): String { + val l = listOf( + Root(Subtree(123, listOf(1), listOf(D("ff")))), + null + ) + val df = l.toDataFrame(maxDepth = 2) + df.compareSchemas(strict = true) + return "OK" +} diff --git a/plugins/kotlin-dataframe/testData/box/toDataFrame_nullableSubtree.kt b/plugins/kotlin-dataframe/testData/box/toDataFrame_nullableSubtree.kt new file mode 100644 index 0000000000..eda97d9298 --- /dev/null +++ b/plugins/kotlin-dataframe/testData/box/toDataFrame_nullableSubtree.kt @@ -0,0 +1,27 @@ +import org.jetbrains.kotlinx.dataframe.* +import org.jetbrains.kotlinx.dataframe.annotations.* +import org.jetbrains.kotlinx.dataframe.api.* +import org.jetbrains.kotlinx.dataframe.io.* + +@DataSchema +data class D( + val s: String +) + +class Subtree( + val p: Int, + val l: List, + val ld: List, +) + +class Root(val a: Subtree?) + +fun box(): String { + val l = listOf( + Root(Subtree(123, listOf(1), listOf(D("ff")))), + Root(null) + ) + val df = l.toDataFrame(maxDepth = 2) + df.compareSchemas(strict = true) + return "OK" +} diff --git a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java index 9d35da9f28..b2f59c4b6a 100644 --- a/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java +++ b/plugins/kotlin-dataframe/tests-gen/org/jetbrains/kotlin/fir/dataframe/DataFrameBlackBoxCodegenTestGenerated.java @@ -418,6 +418,12 @@ public void testToDataFrame_column() { runTest("testData/box/toDataFrame_column.kt"); } + @Test + @TestMetadata("toDataFrame_customIterable.kt") + public void testToDataFrame_customIterable() { + runTest("testData/box/toDataFrame_customIterable.kt"); + } + @Test @TestMetadata("toDataFrame_dataSchema.kt") public void testToDataFrame_dataSchema() { @@ -436,6 +442,24 @@ public void testToDataFrame_from() { runTest("testData/box/toDataFrame_from.kt"); } + @Test + @TestMetadata("toDataFrame_nullableList.kt") + public void testToDataFrame_nullableList() { + runTest("testData/box/toDataFrame_nullableList.kt"); + } + + @Test + @TestMetadata("toDataFrame_nullableListSubtree.kt") + public void testToDataFrame_nullableListSubtree() { + runTest("testData/box/toDataFrame_nullableListSubtree.kt"); + } + + @Test + @TestMetadata("toDataFrame_nullableSubtree.kt") + public void testToDataFrame_nullableSubtree() { + runTest("testData/box/toDataFrame_nullableSubtree.kt"); + } + @Test @TestMetadata("toDataFrame_superType.kt") public void testToDataFrame_superType() {