From 578d5c82ecd534f4e78ea9a88eb9198c2dba9b3b Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Sun, 16 Mar 2025 15:58:39 +0100 Subject: [PATCH 1/5] starting sum statistic rework --- .../jetbrains/kotlinx/dataframe/api/sum.kt | 22 ++- .../jetbrains/kotlinx/dataframe/math/sum.kt | 151 ++++++------------ 2 files changed, 67 insertions(+), 106 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index c0bda09485..a62494f0b3 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -20,14 +20,15 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf +import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns +import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import org.jetbrains.kotlinx.dataframe.impl.zero import org.jetbrains.kotlinx.dataframe.math.sum import org.jetbrains.kotlinx.dataframe.math.sumOf import kotlin.reflect.KProperty -import kotlin.reflect.full.isSubtypeOf import kotlin.reflect.typeOf // region DataColumn @@ -46,13 +47,18 @@ public inline fun DataColumn.sumOf(crossinline expres // region DataRow public fun AnyRow.rowSum(): Number = - Aggregators.sum.aggregateCalculatingType( - values = values().filterIsInstance(), - valueTypes = columnTypes().filter { it.isSubtypeOf(typeOf()) }.toSet(), - ) ?: 0 - -public inline fun AnyRow.rowSumOf(): T = values().filterIsInstance().sum(typeOf()) - + Aggregators.sum.aggregateOfRow(this) { + colsOf { it.isPrimitiveNumber() } + } ?: 0.0 + +public inline fun AnyRow.rowSumOf(): Number /*todo*/ { + require(typeOf() in primitiveNumberTypes) { + "Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types." + } + return Aggregators.sum + .aggregateOfRow(this) { colsOf() } + ?: 0.0 +} // endregion // region DataFrame diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt index 07de30db44..9e578aec16 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt @@ -1,139 +1,94 @@ package org.jetbrains.kotlinx.dataframe.math import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.CalculateReturnTypeOrNull -import java.math.BigDecimal -import java.math.BigInteger +import org.jetbrains.kotlinx.dataframe.impl.nothingType +import org.jetbrains.kotlinx.dataframe.impl.renderType import kotlin.reflect.KType import kotlin.reflect.full.withNullability +import kotlin.reflect.typeOf +import kotlin.sequences.filterNotNull @PublishedApi -internal fun Iterable.sumOf(type: KType, selector: (T) -> R?): R { +internal fun Iterable.sumOf(type: KType, selector: (T) -> R?): Number = + asSequence().sumOf(type, selector) + +@Suppress("UNCHECKED_CAST") +@PublishedApi +internal fun Sequence.sumOf(type: KType, selector: (T) -> R?): Number { if (type.isMarkedNullable) { - val seq = asSequence().mapNotNull(selector).asIterable() - return seq.sum(type) + return filterNotNull().sumOf(type.withNullability(false), selector) } - return when (type.classifier) { - Double::class -> sumOf(selector as ((T) -> Double)) as R - - // careful, conversion to Double to Float occurs! TODO, Issue #558 - Float::class -> sumOf { (selector as ((T) -> Float))(it).toDouble() }.toFloat() as R + return when (type.withNullability(false)) { + typeOf() -> sumOf(selector as (T) -> Double) - Int::class -> sumOf(selector as ((T) -> Int)) as R + typeOf() -> map(selector as (T) -> Float).sum() - // careful, conversion to Int occurs! TODO, Issue #558 - Short::class -> sumOf { (selector as ((T) -> Short))(it).toInt() }.toShort() as R + typeOf() -> sumOf(selector as (T) -> Int) - // careful, conversion to Int occurs! TODO, Issue #558 - Byte::class -> sumOf { (selector as ((T) -> Byte))(it).toInt() }.toByte() as R + // Note: returns Int + typeOf() -> map(selector as (T) -> Short).sum() - Long::class -> sumOf(selector as ((T) -> Long)) as R + // Note: returns Int + typeOf() -> map(selector as (T) -> Byte).sum() - BigDecimal::class -> sumOf(selector as ((T) -> BigDecimal)) as R + typeOf() -> sumOf(selector as (T) -> Long) - BigInteger::class -> sumOf(selector as ((T) -> BigInteger)) as R + nothingType -> 0.0 - Number::class -> sumOf { (selector as ((T) -> Number))(it).toDouble() } as R - - Nothing::class -> 0.0 as R + typeOf() -> + error("Encountered non-specific Number type in sumOf function. This should not occur.") else -> throw IllegalArgumentException("sumOf is not supported for $type") } } -@PublishedApi -internal fun Iterable.sum(type: KType): T = - when (type.classifier) { - Double::class -> (this as Iterable).sum() as T - - Float::class -> (this as Iterable).sum() as T - - Int::class -> (this as Iterable).sum() as T - - // TODO result should be Int, but same type as input is returned, Issue #558 - Short::class -> (this as Iterable).sum().toShort() as T - - // TODO result should be Int, but same type as input is returned, Issue #558 - Byte::class -> (this as Iterable).sum().toByte() as T - - Long::class -> (this as Iterable).sum() as T - - BigDecimal::class -> (this as Iterable).sum() as T - - BigInteger::class -> (this as Iterable).sum() as T - - Number::class -> (this as Iterable).map { it.toDouble() }.sum() as T - - Nothing::class -> 0.0 as T - - else -> throw IllegalArgumentException("sum is not supported for $type") - } +internal fun Iterable.sum(type: KType): Number = asSequence().sum(type) +@Suppress("UNCHECKED_CAST") @JvmName("sumNullableT") @PublishedApi -internal fun Iterable.sum(type: KType): T = - when (type.classifier) { - Double::class -> (this as Iterable).asSequence().filterNotNull().sum() as T - - Float::class -> (this as Iterable).asSequence().filterNotNull().sum() as T - - Int::class -> (this as Iterable).asSequence().filterNotNull().sum() as T +internal fun Sequence.sum(type: KType): Number { + if (type.isMarkedNullable) { + return filterNotNull().sum(type.withNullability(false)) + } + return when (type.withNullability(false)) { + typeOf() -> (this as Sequence).sum() - // TODO result should be Int, but same type as input is returned, Issue #558 - Short::class -> (this as Iterable).asSequence().filterNotNull().sum().toShort() as T + typeOf() -> (this as Sequence).sum() - // TODO result should be Int, but same type as input is returned, Issue #558 - Byte::class -> (this as Iterable).asSequence().filterNotNull().sum().toByte() as T + typeOf() -> (this as Sequence).sum() - Long::class -> (this as Iterable).asSequence().filterNotNull().sum() as T + // Note: returns Int + typeOf() -> (this as Sequence).sum() - BigDecimal::class -> (this as Iterable).asSequence().filterNotNull().sum() as T + // Note: returns Int + typeOf() -> (this as Sequence).sum() - BigInteger::class -> (this as Iterable).asSequence().filterNotNull().sum() as T + typeOf() -> (this as Sequence).sum() - Number::class -> (this as Iterable).asSequence().filterNotNull().map { it.toDouble() }.sum() as T + typeOf() -> + error("Encountered non-specific Number type in sum function. This should not occur.") - Nothing::class -> 0.0 as T + nothingType -> 0.0 - else -> throw IllegalArgumentException("sum is not supported for $type") + else -> throw IllegalArgumentException( + "Unable to compute the sum for ${renderType(type)}, Only primitive numbers are supported.", + ) } +} /** T: Number? -> T */ internal val sumTypeConversion: CalculateReturnTypeOrNull = { type, _ -> - type.withNullability(false) -} - -@PublishedApi -internal fun Iterable.sum(): BigDecimal { - var sum: BigDecimal = BigDecimal.ZERO - for (element in this) { - sum += element - } - return sum -} + when (type.withNullability(false)) { + typeOf(), + typeOf(), + typeOf(), + -> typeOf() -@PublishedApi -internal fun Sequence.sum(): BigDecimal { - var sum: BigDecimal = BigDecimal.ZERO - for (element in this) { - sum += element - } - return sum -} + typeOf() -> typeOf() -@PublishedApi -internal fun Iterable.sum(): BigInteger { - var sum: BigInteger = BigInteger.ZERO - for (element in this) { - sum += element - } - return sum -} + typeOf() -> typeOf() -@PublishedApi -internal fun Sequence.sum(): BigInteger { - var sum: BigInteger = BigInteger.ZERO - for (element in this) { - sum += element + else -> typeOf() } - return sum } From 13150c445838295949c2500eab232b443b26625d Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 17 Mar 2025 16:35:40 +0100 Subject: [PATCH 2/5] adding overloads for DataColumn.sum --- .../jetbrains/kotlinx/dataframe/api/sum.kt | 28 +++++++++++++------ .../jetbrains/kotlinx/dataframe/math/sum.kt | 2 ++ 2 files changed, 22 insertions(+), 8 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index a62494f0b3..02ad242983 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -13,7 +13,6 @@ import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf -import org.jetbrains.kotlinx.dataframe.columns.values import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast @@ -26,18 +25,32 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import org.jetbrains.kotlinx.dataframe.impl.zero -import org.jetbrains.kotlinx.dataframe.math.sum import org.jetbrains.kotlinx.dataframe.math.sumOf import kotlin.reflect.KProperty import kotlin.reflect.typeOf // region DataColumn -@JvmName("sumT") -public fun DataColumn.sum(): T = values.sum(type()) +@JvmName("sumInt") +public fun DataColumn.sum(): Int = Aggregators.sum.aggregate(this) as Int -@JvmName("sumTNullable") -public fun DataColumn.sum(): T = values.sum(type()) +@JvmName("sumShort") +public fun DataColumn.sum(): Int = Aggregators.sum.aggregate(this) as Int + +@JvmName("sumByte") +public fun DataColumn.sum(): Int = Aggregators.sum.aggregate(this) as Int + +@JvmName("sumLong") +public fun DataColumn.sum(): Long = Aggregators.sum.aggregate(this) as Long + +@JvmName("sumFloat") +public fun DataColumn.sum(): Float = Aggregators.sum.aggregate(this) as Float + +@JvmName("sumDouble") +public fun DataColumn.sum(): Double = Aggregators.sum.aggregate(this) as Double + +@JvmName("sumNumber") +public fun DataColumn.sum(): Number = Aggregators.sum.aggregate(this) public inline fun DataColumn.sumOf(crossinline expression: (T) -> R): R? = (Aggregators.sum as Aggregator<*, *>).cast().of(this, expression) @@ -49,7 +62,7 @@ public inline fun DataColumn.sumOf(crossinline expres public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this) { colsOf { it.isPrimitiveNumber() } - } ?: 0.0 + } public inline fun AnyRow.rowSumOf(): Number /*todo*/ { require(typeOf() in primitiveNumberTypes) { @@ -57,7 +70,6 @@ public inline fun AnyRow.rowSumOf(): Number /*todo*/ { } return Aggregators.sum .aggregateOfRow(this) { colsOf() } - ?: 0.0 } // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt index 9e578aec16..6361c32c67 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt @@ -85,6 +85,8 @@ internal val sumTypeConversion: CalculateReturnTypeOrNull = { type, _ -> typeOf(), -> typeOf() + typeOf() -> typeOf() + typeOf() -> typeOf() typeOf() -> typeOf() From b046891ccc43c0992d515a717b501576ad7082e5 Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Mon, 17 Mar 2025 17:58:44 +0100 Subject: [PATCH 3/5] update from mean branch, added sumOf --- .../jetbrains/kotlinx/dataframe/api/sum.kt | 40 ++++++++++++++++--- .../impl/aggregation/modes/forEveryColumn.kt | 2 +- .../aggregation/modes/withinAllColumns.kt | 10 ++--- .../jetbrains/kotlinx/dataframe/math/sum.kt | 34 ---------------- 4 files changed, 40 insertions(+), 46 deletions(-) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index dbf6be6075..a9ca93f356 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -1,3 +1,5 @@ +@file:OptIn(ExperimentalTypeInference::class) + package org.jetbrains.kotlinx.dataframe.api import org.jetbrains.kotlinx.dataframe.AnyRow @@ -13,19 +15,16 @@ import org.jetbrains.kotlinx.dataframe.annotations.Refine import org.jetbrains.kotlinx.dataframe.columns.ColumnReference import org.jetbrains.kotlinx.dataframe.columns.toColumnSet import org.jetbrains.kotlinx.dataframe.columns.toColumnsSetOf -import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregators -import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.cast import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow -import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.of import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import org.jetbrains.kotlinx.dataframe.impl.zero -import org.jetbrains.kotlinx.dataframe.math.sumOf +import kotlin.experimental.ExperimentalTypeInference import kotlin.reflect.KProperty import kotlin.reflect.typeOf @@ -52,8 +51,37 @@ public fun DataColumn.sum(): Double = Aggregators.sum.aggregate(this) a @JvmName("sumNumber") public fun DataColumn.sum(): Number = Aggregators.sum.aggregate(this) -public inline fun DataColumn.sumOf(noinline expression: (T) -> R): R? = - (Aggregators.sum as Aggregator<*, *>).cast().aggregateOf(this, expression) +@JvmName("sumOfInt") +@OverloadResolutionByLambdaReturnType +public fun DataColumn.sumOf(expression: (T) -> Int?): Int = Aggregators.sum.aggregateOf(this, expression) as Int + +@JvmName("sumOfShort") +@OverloadResolutionByLambdaReturnType +public fun DataColumn.sumOf(expression: (T) -> Short?): Int = + Aggregators.sum.aggregateOf(this, expression) as Int + +@JvmName("sumOfByte") +@OverloadResolutionByLambdaReturnType +public fun DataColumn.sumOf(expression: (T) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int + +@JvmName("sumOfLong") +@OverloadResolutionByLambdaReturnType +public fun DataColumn.sumOf(expression: (T) -> Long?): Long = + Aggregators.sum.aggregateOf(this, expression) as Long + +@JvmName("sumOfFloat") +@OverloadResolutionByLambdaReturnType +public fun DataColumn.sumOf(expression: (T) -> Float?): Float = + Aggregators.sum.aggregateOf(this, expression) as Float + +@JvmName("sumOfDouble") +@OverloadResolutionByLambdaReturnType +public fun DataColumn.sumOf(expression: (T) -> Double?): Double = + Aggregators.sum.aggregateOf(this, expression) as Double + +@JvmName("sumOfNumber") +@OverloadResolutionByLambdaReturnType +public fun DataColumn.sumOf(expression: (T) -> Number?): Number = Aggregators.sum.aggregateOf(this, expression) // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/forEveryColumn.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/forEveryColumn.kt index 6ec6459ff0..70e55e2f8b 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/forEveryColumn.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/forEveryColumn.kt @@ -41,7 +41,7 @@ internal fun Aggregator<*, R>.aggregateFor( } internal fun AggregateInternalDsl.aggregateFor( - columns: ColumnsForAggregateSelector, + columns: ColumnsForAggregateSelector, aggregator: Aggregator, ) { val cols = df.getAggregateColumns(columns) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt index b7217d3b1e..7467f2da3f 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/withinAllColumns.kt @@ -15,18 +15,18 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.internal import org.jetbrains.kotlinx.dataframe.impl.emptyPath @PublishedApi -internal fun Aggregator<*, R>.aggregateAll(data: DataFrame, columns: ColumnsSelector): R = +internal fun Aggregator<*, R>.aggregateAll(data: DataFrame, columns: ColumnsSelector): R = data.aggregateAll(cast2(), columns) internal fun Aggregator<*, R>.aggregateAll( data: Grouped, name: String?, - columns: ColumnsSelector, + columns: ColumnsSelector, ): DataFrame = data.aggregateAll(cast(), columns, name) internal fun Aggregator<*, R>.aggregateAll( data: PivotGroupBy, - columns: ColumnsSelector, + columns: ColumnsSelector, ): DataFrame = data.aggregateAll(cast(), columns) internal fun DataFrame.aggregateAll(aggregator: Aggregator, columns: ColumnsSelector): R = @@ -34,7 +34,7 @@ internal fun DataFrame.aggregateAll(aggregator: Aggregator, c internal fun Grouped.aggregateAll( aggregator: Aggregator, - columns: ColumnsSelector, + columns: ColumnsSelector, name: String?, ): DataFrame = aggregateInternal { @@ -48,7 +48,7 @@ internal fun Grouped.aggregateAll( internal fun PivotGroupBy.aggregateAll( aggregator: Aggregator, - columns: ColumnsSelector, + columns: ColumnsSelector, ): DataFrame = aggregate { val cols = get(columns) diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt index 6361c32c67..3f6ccbe5dc 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt @@ -8,40 +8,6 @@ import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf import kotlin.sequences.filterNotNull -@PublishedApi -internal fun Iterable.sumOf(type: KType, selector: (T) -> R?): Number = - asSequence().sumOf(type, selector) - -@Suppress("UNCHECKED_CAST") -@PublishedApi -internal fun Sequence.sumOf(type: KType, selector: (T) -> R?): Number { - if (type.isMarkedNullable) { - return filterNotNull().sumOf(type.withNullability(false), selector) - } - return when (type.withNullability(false)) { - typeOf() -> sumOf(selector as (T) -> Double) - - typeOf() -> map(selector as (T) -> Float).sum() - - typeOf() -> sumOf(selector as (T) -> Int) - - // Note: returns Int - typeOf() -> map(selector as (T) -> Short).sum() - - // Note: returns Int - typeOf() -> map(selector as (T) -> Byte).sum() - - typeOf() -> sumOf(selector as (T) -> Long) - - nothingType -> 0.0 - - typeOf() -> - error("Encountered non-specific Number type in sumOf function. This should not occur.") - - else -> throw IllegalArgumentException("sumOf is not supported for $type") - } -} - internal fun Iterable.sum(type: KType): Number = asSequence().sum(type) @Suppress("UNCHECKED_CAST") From 9aaf84d75cd9734fce51510f94780ae09355005b Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Tue, 18 Mar 2025 13:43:22 +0100 Subject: [PATCH 4/5] extra overloads for sum api, fixed parts of aggregateOfRow --- core/api/core.api | 24 ++- .../jetbrains/kotlinx/dataframe/api/mean.kt | 12 +- .../jetbrains/kotlinx/dataframe/api/sum.kt | 158 +++++++++++------- .../dataframe/impl/aggregation/getColumns.kt | 6 + .../dataframe/impl/aggregation/modes/row.kt | 4 +- .../jetbrains/kotlinx/dataframe/math/sum.kt | 16 +- 6 files changed, 135 insertions(+), 85 deletions(-) diff --git a/core/api/core.api b/core/api/core.api index 0f002f0ab0..d60b81750f 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -3816,6 +3816,7 @@ public final class org/jetbrains/kotlinx/dataframe/api/StdKt { public final class org/jetbrains/kotlinx/dataframe/api/SumKt { public static final fun rowSum (Lorg/jetbrains/kotlinx/dataframe/DataRow;)Ljava/lang/Number; + public static final fun rowSumOf (Lorg/jetbrains/kotlinx/dataframe/DataRow;Lkotlin/reflect/KType;)Ljava/lang/Number; public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;)Lorg/jetbrains/kotlinx/dataframe/DataRow; public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Ljava/lang/Number; public static final fun sum (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; @@ -3839,6 +3840,10 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt { public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Grouped;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;Ljava/lang/String;ILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/Pivot;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataRow; public static synthetic fun sum$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; + public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I + public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I + public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I + public static final fun sumByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)Lorg/jetbrains/kotlinx/dataframe/DataRow; public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Ljava/lang/String;)Lorg/jetbrains/kotlinx/dataframe/DataRow; public static final fun sumFor (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)Lorg/jetbrains/kotlinx/dataframe/DataRow; @@ -3863,8 +3868,15 @@ public final class org/jetbrains/kotlinx/dataframe/api/SumKt { public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Ljava/lang/String;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lkotlin/reflect/KProperty;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; public static synthetic fun sumFor$default (Lorg/jetbrains/kotlinx/dataframe/api/PivotGroupBy;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;ZILjava/lang/Object;)Lorg/jetbrains/kotlinx/dataframe/DataFrame; - public static final fun sumT (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number; - public static final fun sumTNullable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number; + public static final fun sumNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Ljava/lang/Number; + public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I + public static final fun sumOfByte (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I + public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/jvm/functions/Function1;)I + public static final fun sumOfShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I + public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)I + public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;Lkotlin/jvm/functions/Function2;)I + public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lkotlin/reflect/KProperty;)I + public static final fun sumShort (Lorg/jetbrains/kotlinx/dataframe/DataFrame;[Lorg/jetbrains/kotlinx/dataframe/columns/ColumnReference;)I } public final class org/jetbrains/kotlinx/dataframe/api/TailKt { @@ -6153,13 +6165,7 @@ public final class org/jetbrains/kotlinx/dataframe/math/StdKt { } public final class org/jetbrains/kotlinx/dataframe/math/SumKt { - public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigDecimal; - public static final fun sum (Ljava/lang/Iterable;)Ljava/math/BigInteger; - public static final fun sum (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number; - public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigDecimal; - public static final fun sum (Lkotlin/sequences/Sequence;)Ljava/math/BigInteger; - public static final fun sumNullableT (Ljava/lang/Iterable;Lkotlin/reflect/KType;)Ljava/lang/Number; - public static final fun sumOf (Ljava/lang/Iterable;Lkotlin/reflect/KType;Lkotlin/jvm/functions/Function1;)Ljava/lang/Number; + public static final fun sumNullableT (Lkotlin/sequences/Sequence;Lkotlin/reflect/KType;)Ljava/lang/Number; } public abstract class org/jetbrains/kotlinx/dataframe/schema/ColumnSchema { diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index bbdb9707d0..bfb7f88edf 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -25,10 +25,9 @@ import kotlin.reflect.KProperty import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf -/* - * TODO KDocs: +/* TODO KDocs: * Calculating the mean is supported for all primitive number types. - * Nulls are filtered from columns. + * Nulls are filtered out. * The return type is always Double, Double.NaN for empty input, never null. * (May introduce loss of precision for Longs). * For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the mean. @@ -48,16 +47,13 @@ public inline fun DataColumn.meanOf( // region DataRow public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double = - Aggregators.mean(skipNA).aggregateOfRow(this) { - colsOf { it.isPrimitiveNumber() } - } + Aggregators.mean(skipNA).aggregateOfRow(this, primitiveNumberColumns()) public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double { require(typeOf().withNullability(false) in primitiveNumberTypes) { "Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types." } - return Aggregators.mean(skipNA) - .aggregateOfRow(this) { colsOf() } + return Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf() } } // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index a9ca93f356..599e5497e5 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -1,4 +1,5 @@ @file:OptIn(ExperimentalTypeInference::class) +@file:Suppress("LocalVariableName") package org.jetbrains.kotlinx.dataframe.api @@ -20,18 +21,26 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow -import org.jetbrains.kotlinx.dataframe.impl.aggregation.numberColumns +import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes -import org.jetbrains.kotlinx.dataframe.impl.zero import kotlin.experimental.ExperimentalTypeInference +import kotlin.reflect.KClass import kotlin.reflect.KProperty +import kotlin.reflect.KType +import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf -// region DataColumn +/* TODO KDocs + * Calculating the sum is supported for all primitive number types. + * Nulls are filtered out. + * The return type is always the same as the input type (never null), except for `Byte` and `Short`, + * which are converted to `Int`. + * Empty input will result in 0 in the supplied number type. + * For mixed primitive number types, [TwoStepNumbersAggregator] unifies the numbers before calculating the sum. + */ -@JvmName("sumInt") -public fun DataColumn.sum(): Int = Aggregators.sum.aggregate(this) as Int +// region DataColumn @JvmName("sumShort") public fun DataColumn.sum(): Int = Aggregators.sum.aggregate(this) as Int @@ -39,71 +48,66 @@ public fun DataColumn.sum(): Int = Aggregators.sum.aggregate(this) as In @JvmName("sumByte") public fun DataColumn.sum(): Int = Aggregators.sum.aggregate(this) as Int -@JvmName("sumLong") -public fun DataColumn.sum(): Long = Aggregators.sum.aggregate(this) as Long - -@JvmName("sumFloat") -public fun DataColumn.sum(): Float = Aggregators.sum.aggregate(this) as Float - -@JvmName("sumDouble") -public fun DataColumn.sum(): Double = Aggregators.sum.aggregate(this) as Double - +@Suppress("UNCHECKED_CAST") @JvmName("sumNumber") -public fun DataColumn.sum(): Number = Aggregators.sum.aggregate(this) - -@JvmName("sumOfInt") -@OverloadResolutionByLambdaReturnType -public fun DataColumn.sumOf(expression: (T) -> Int?): Int = Aggregators.sum.aggregateOf(this, expression) as Int +public fun DataColumn.sum(): T = Aggregators.sum.aggregate(this) as T @JvmName("sumOfShort") @OverloadResolutionByLambdaReturnType -public fun DataColumn.sumOf(expression: (T) -> Short?): Int = +public fun DataColumn.sumOf(expression: (C) -> Short?): Int = Aggregators.sum.aggregateOf(this, expression) as Int @JvmName("sumOfByte") @OverloadResolutionByLambdaReturnType -public fun DataColumn.sumOf(expression: (T) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int - -@JvmName("sumOfLong") -@OverloadResolutionByLambdaReturnType -public fun DataColumn.sumOf(expression: (T) -> Long?): Long = - Aggregators.sum.aggregateOf(this, expression) as Long - -@JvmName("sumOfFloat") -@OverloadResolutionByLambdaReturnType -public fun DataColumn.sumOf(expression: (T) -> Float?): Float = - Aggregators.sum.aggregateOf(this, expression) as Float - -@JvmName("sumOfDouble") -@OverloadResolutionByLambdaReturnType -public fun DataColumn.sumOf(expression: (T) -> Double?): Double = - Aggregators.sum.aggregateOf(this, expression) as Double +public fun DataColumn.sumOf(expression: (C) -> Byte?): Int = Aggregators.sum.aggregateOf(this, expression) as Int @JvmName("sumOfNumber") @OverloadResolutionByLambdaReturnType -public fun DataColumn.sumOf(expression: (T) -> Number?): Number = Aggregators.sum.aggregateOf(this, expression) +public inline fun DataColumn.sumOf(crossinline expression: (C) -> V?): V = + Aggregators.sum.aggregateOf(this, expression) as V // endregion // region DataRow -public fun AnyRow.rowSum(): Number = - Aggregators.sum.aggregateOfRow(this) { - colsOf { it.isPrimitiveNumber() } - } +public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this, primitiveNumberColumns()) -public inline fun AnyRow.rowSumOf(): Number /*todo*/ { - require(typeOf() in primitiveNumberTypes) { - "Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types." +@JvmName("rowSumOfShort") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Short::class): Int = + rowSumOf(typeOf()) as Int + +@JvmName("rowSumOfByte") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Byte::class): Int = + rowSumOf(typeOf()) as Int + +@JvmName("rowSumOfInt") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Int::class): Int = + rowSumOf(typeOf()) as Int + +@JvmName("rowSumOfLong") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Long::class): Long = + rowSumOf(typeOf()) as Long + +@JvmName("rowSumOfFloat") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Float::class): Float = + rowSumOf(typeOf()) as Float + +@JvmName("rowSumOfDouble") +public inline fun AnyRow.rowSumOf(_kClass: KClass = Double::class): Double = + rowSumOf(typeOf()) as Double + +// unfortunately, we cannot make a `reified T : Number?` due to clashes +public fun AnyRow.rowSumOf(type: KType): Number { + require(type.withNullability(false) in primitiveNumberTypes) { + "Type $type is not a primitive number type. Mean only supports primitive number types." } - return Aggregators.sum - .aggregateOfRow(this) { colsOf() } + return Aggregators.sum.aggregateOfRow(this) { colsOf(type) } } // endregion // region DataFrame -public fun DataFrame.sum(): DataRow = sumFor(numberColumns()) +public fun DataFrame.sum(): DataRow = sumFor(primitiveNumberColumns()) public fun DataFrame.sumFor(columns: ColumnsForAggregateSelector): DataRow = Aggregators.sum.aggregateFor(this, columns) @@ -118,28 +122,70 @@ public fun DataFrame.sumFor(vararg columns: ColumnReference DataFrame.sumFor(vararg columns: KProperty): DataRow = sumFor { columns.toColumnSet() } +@JvmName("sumShort") +@OverloadResolutionByLambdaReturnType +public fun DataFrame.sum(columns: ColumnsSelector): Int = + Aggregators.sum.aggregateAll(this, columns) as Int + +@JvmName("sumByte") +@OverloadResolutionByLambdaReturnType +public fun DataFrame.sum(columns: ColumnsSelector): Int = + Aggregators.sum.aggregateAll(this, columns) as Int + +@JvmName("sumNumber") +@OverloadResolutionByLambdaReturnType public inline fun DataFrame.sum(noinline columns: ColumnsSelector): C = - (Aggregators.sum.aggregateAll(this, columns) as C?) ?: C::class.zero() + Aggregators.sum.aggregateAll(this, columns) as C +@JvmName("sumShort") +@AccessApiOverload +public fun DataFrame.sum(vararg columns: ColumnReference): Int = sum { columns.toColumnSet() } + +@JvmName("sumByte") +@AccessApiOverload +public fun DataFrame.sum(vararg columns: ColumnReference): Int = sum { columns.toColumnSet() } + +@JvmName("sumNumber") @AccessApiOverload public inline fun DataFrame.sum(vararg columns: ColumnReference): C = sum { columns.toColumnSet() } -public fun DataFrame.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf() } +public fun DataFrame.sum(vararg columns: String): Number = sum { columns.toColumnsSetOf() } + +@JvmName("sumShort") +@AccessApiOverload +public fun DataFrame.sum(vararg columns: KProperty): Int = sum { columns.toColumnSet() } + +@JvmName("sumByte") +@AccessApiOverload +public fun DataFrame.sum(vararg columns: KProperty): Int = sum { columns.toColumnSet() } +@JvmName("sumNumber") @AccessApiOverload public inline fun DataFrame.sum(vararg columns: KProperty): C = sum { columns.toColumnSet() } -public inline fun DataFrame.sumOf(crossinline expression: RowExpression): C = - rows().sumOf(typeOf()) { expression(it, it) } +@JvmName("sumOfShort") +@OverloadResolutionByLambdaReturnType +public fun DataFrame.sumOf(expression: RowExpression): Int = + Aggregators.sum.aggregateOf(this, expression) as Int + +@JvmName("sumOfByte") +@OverloadResolutionByLambdaReturnType +public fun DataFrame.sumOf(expression: RowExpression): Int = + Aggregators.sum.aggregateOf(this, expression) as Int + +@JvmName("sumOfNumber") +@OverloadResolutionByLambdaReturnType +public inline fun DataFrame.sumOf(crossinline expression: RowExpression): C = + Aggregators.sum.aggregateOf(this, expression) as C // endregion // region GroupBy @Refine @Interpretable("GroupBySum1") -public fun Grouped.sum(): DataFrame = sumFor(numberColumns()) +public fun Grouped.sum(): DataFrame = sumFor(primitiveNumberColumns()) @Refine @Interpretable("GroupBySum0") @@ -183,7 +229,7 @@ public inline fun Grouped.sumOf( // region Pivot -public fun Pivot.sum(separate: Boolean = false): DataRow = sumFor(separate, numberColumns()) +public fun Pivot.sum(separate: Boolean = false): DataRow = sumFor(separate, primitiveNumberColumns()) public fun Pivot.sumFor( separate: Boolean = false, @@ -213,14 +259,14 @@ public fun Pivot.sum(vararg columns: ColumnReference): Da @AccessApiOverload public fun Pivot.sum(vararg columns: KProperty): DataRow = sum { columns.toColumnSet() } -public inline fun Pivot.sumOf(crossinline expression: RowExpression): DataRow = +public inline fun Pivot.sumOf(crossinline expression: RowExpression): DataRow = delegate { sumOf(expression) } // endregion // region PivotGroupBy -public fun PivotGroupBy.sum(separate: Boolean = false): DataFrame = sumFor(separate, numberColumns()) +public fun PivotGroupBy.sum(separate: Boolean = false): DataFrame = sumFor(separate, primitiveNumberColumns()) public fun PivotGroupBy.sumFor( separate: Boolean = false, @@ -256,7 +302,7 @@ public fun PivotGroupBy.sum(vararg columns: KProperty): D sum { columns.toColumnSet() } public inline fun PivotGroupBy.sumOf( - crossinline expression: RowExpression, + crossinline expression: RowExpression, ): DataFrame = Aggregators.sum.aggregateOf(this, expression) // endregion diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt index 66f598b02d..503efd92c5 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt @@ -2,14 +2,17 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation import org.jetbrains.kotlinx.dataframe.AnyCol import org.jetbrains.kotlinx.dataframe.ColumnsSelector +import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.aggregation.Aggregatable import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue +import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.filter import org.jetbrains.kotlinx.dataframe.api.isNumber import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber import org.jetbrains.kotlinx.dataframe.api.valuesAreComparable import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion import org.jetbrains.kotlinx.dataframe.impl.columns.createColumnGuessingType +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber internal inline fun Aggregatable.remainingColumns( crossinline predicate: (AnyCol) -> Boolean, @@ -27,6 +30,9 @@ internal fun Aggregatable.numberColumns(): ColumnsSelector = internal fun Aggregatable.primitiveNumberColumns(): ColumnsSelector = remainingColumns { it.isPrimitiveNumber() } as ColumnsSelector +internal fun DataRow.primitiveNumberColumns(): ColumnsSelector = + { cols { it.isPrimitiveNumber() }.cast() } + internal fun NamedValue.toColumnWithPath() = path to createColumnGuessingType( name = path.last(), diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt index 53abbbe9b5..d6adab9304 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/modes/row.kt @@ -1,7 +1,7 @@ package org.jetbrains.kotlinx.dataframe.impl.aggregation.modes -import org.jetbrains.kotlinx.dataframe.AnyRow import org.jetbrains.kotlinx.dataframe.ColumnsSelector +import org.jetbrains.kotlinx.dataframe.DataRow import org.jetbrains.kotlinx.dataframe.api.getColumns import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator @@ -14,7 +14,7 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.aggregators.Aggregator * @param columns selector of which columns inside the [row] to aggregate */ @PublishedApi -internal fun Aggregator.aggregateOfRow(row: AnyRow, columns: ColumnsSelector<*, V?>): R { +internal fun Aggregator.aggregateOfRow(row: DataRow, columns: ColumnsSelector): R { val filteredColumns = row.df().getColumns(columns) return aggregateCalculatingType( values = filteredColumns.mapNotNull { row[it] }, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt index 3f6ccbe5dc..613c19ecb1 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt @@ -45,18 +45,14 @@ internal fun Sequence.sum(type: KType): Number { /** T: Number? -> T */ internal val sumTypeConversion: CalculateReturnTypeOrNull = { type, _ -> - when (type.withNullability(false)) { - typeOf(), - typeOf(), - typeOf(), - -> typeOf() + when (val type = type.withNullability(false)) { + // type changes to Int + typeOf(), typeOf() -> typeOf() - typeOf() -> typeOf() - - typeOf() -> typeOf() - - typeOf() -> typeOf() + // type remains the same + typeOf(), typeOf(), typeOf(), typeOf() -> type + // defaults to Double else -> typeOf() } } From 7502a80c86ec4cf00c6f7b487c9dd62f02cad52b Mon Sep 17 00:00:00 2001 From: Jolan Rensen Date: Wed, 19 Mar 2025 15:04:03 +0100 Subject: [PATCH 5/5] made mean and sum use isPrimitiveOrMixedNumber(). Number unification can now return `null`, so we can give helpful errors from the aggregator --- core/api/core.api | 5 ++ .../kotlinx/dataframe/api/DataColumnType.kt | 18 ++++- .../jetbrains/kotlinx/dataframe/api/mean.kt | 17 ++-- .../jetbrains/kotlinx/dataframe/api/sum.kt | 18 ++--- .../kotlinx/dataframe/impl/NumberTypeUtils.kt | 78 +++++++++++++------ .../kotlinx/dataframe/impl/TypeUtils.kt | 2 +- .../aggregators/TwoStepNumbersAggregator.kt | 32 +++++--- .../dataframe/impl/aggregation/getColumns.kt | 11 ++- .../jetbrains/kotlinx/dataframe/math/sum.kt | 8 +- .../kotlinx/dataframe/statistics/sum.kt | 22 +++++- .../kotlinx/dataframe/types/UtilTests.kt | 46 +++++------ 11 files changed, 163 insertions(+), 94 deletions(-) diff --git a/core/api/core.api b/core/api/core.api index d60b81750f..f03122f4c3 100644 --- a/core/api/core.api +++ b/core/api/core.api @@ -1781,8 +1781,10 @@ public final class org/jetbrains/kotlinx/dataframe/api/DataColumnTypeKt { public static final fun isComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isFrameColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isList (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z + public static final fun isMixedNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isPrimitiveNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z + public static final fun isPrimitiveOrMixedNumber (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun isSubtypeOf (Lorg/jetbrains/kotlinx/dataframe/DataColumn;Lkotlin/reflect/KType;)Z public static final fun isValueColumn (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z public static final fun valuesAreComparable (Lorg/jetbrains/kotlinx/dataframe/DataColumn;)Z @@ -5116,6 +5118,9 @@ public final class org/jetbrains/kotlinx/dataframe/impl/ExceptionUtilsKt { public final class org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtilsKt { public static final fun getPrimitiveNumberTypes ()Ljava/util/Set; + public static final fun isMixedNumber (Lkotlin/reflect/KType;)Z + public static final fun isPrimitiveNumber (Lkotlin/reflect/KType;)Z + public static final fun isPrimitiveOrMixedNumber (Lkotlin/reflect/KType;)Z } public final class org/jetbrains/kotlinx/dataframe/impl/TypeUtilsKt { diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt index 9800f33463..a69d8db723 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/DataColumnType.kt @@ -7,7 +7,9 @@ import org.jetbrains.kotlinx.dataframe.columns.ColumnGroup import org.jetbrains.kotlinx.dataframe.columns.ColumnKind import org.jetbrains.kotlinx.dataframe.columns.FrameColumn import org.jetbrains.kotlinx.dataframe.columns.ValueColumn -import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes +import org.jetbrains.kotlinx.dataframe.impl.isMixedNumber +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber import org.jetbrains.kotlinx.dataframe.type import org.jetbrains.kotlinx.dataframe.typeClass import org.jetbrains.kotlinx.dataframe.util.IS_COMPARABLE @@ -48,11 +50,23 @@ public inline fun AnyCol.isType(): Boolean = type() == typeOf() /** Returns `true` when this column's type is a subtype of `Number?` */ public fun AnyCol.isNumber(): Boolean = isSubtypeOf() +/** Returns `true` only when this column's type is exactly `Number` or `Number?`. */ +public fun AnyCol.isMixedNumber(): Boolean = type().isMixedNumber() + /** * Returns `true` when this column has the (nullable) type of either: * [Byte], [Short], [Int], [Long], [Float], or [Double]. */ -public fun AnyCol.isPrimitiveNumber(): Boolean = type().withNullability(false) in primitiveNumberTypes +public fun AnyCol.isPrimitiveNumber(): Boolean = type().isPrimitiveNumber() + +/** + * Returns `true` when this column has the (nullable) type of either: + * [Byte], [Short], [Int], [Long], [Float], [Double], or [Number]. + * + * Careful: Will return `true` if the column contains multiple number types that + * might NOT be primitive. + */ +public fun AnyCol.isPrimitiveOrMixedNumber(): Boolean = type().isPrimitiveOrMixedNumber() public fun AnyCol.isList(): Boolean = typeClass == List::class diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt index bfb7f88edf..3ebb607306 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/mean.kt @@ -18,11 +18,10 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow -import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns +import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns -import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber import kotlin.reflect.KProperty -import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf /* TODO KDocs: @@ -47,10 +46,10 @@ public inline fun DataColumn.meanOf( // region DataRow public fun AnyRow.rowMean(skipNA: Boolean = skipNA_default): Double = - Aggregators.mean(skipNA).aggregateOfRow(this, primitiveNumberColumns()) + Aggregators.mean(skipNA).aggregateOfRow(this, primitiveOrMixedNumberColumns()) public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipNA_default): Double { - require(typeOf().withNullability(false) in primitiveNumberTypes) { + require(typeOf().isPrimitiveOrMixedNumber()) { "Type ${T::class.simpleName} is not a primitive number type. Mean only supports primitive number types." } return Aggregators.mean(skipNA).aggregateOfRow(this) { colsOf() } @@ -61,7 +60,7 @@ public inline fun AnyRow.rowMeanOf(skipNA: Boolean = skipN // region DataFrame public fun DataFrame.mean(skipNA: Boolean = skipNA_default): DataRow = - meanFor(skipNA, primitiveNumberColumns()) + meanFor(skipNA, primitiveOrMixedNumberColumns()) public fun DataFrame.meanFor( skipNA: Boolean = skipNA_default, @@ -112,7 +111,7 @@ public inline fun DataFrame.meanOf( @Refine @Interpretable("GroupByMean1") public fun Grouped.mean(skipNA: Boolean = skipNA_default): DataFrame = - meanFor(skipNA, primitiveNumberColumns()) + meanFor(skipNA, primitiveOrMixedNumberColumns()) @Refine @Interpretable("GroupByMean0") @@ -177,7 +176,7 @@ public inline fun Grouped.meanOf( // region Pivot public fun Pivot.mean(skipNA: Boolean = skipNA_default, separate: Boolean = false): DataRow = - meanFor(skipNA, separate, primitiveNumberColumns()) + meanFor(skipNA, separate, primitiveOrMixedNumberColumns()) public fun Pivot.meanFor( skipNA: Boolean = skipNA_default, @@ -220,7 +219,7 @@ public inline fun Pivot.meanOf( // region PivotGroupBy public fun PivotGroupBy.mean(separate: Boolean = false, skipNA: Boolean = skipNA_default): DataFrame = - meanFor(skipNA, separate, primitiveNumberColumns()) + meanFor(skipNA, separate, primitiveOrMixedNumberColumns()) public fun PivotGroupBy.meanFor( skipNA: Boolean = skipNA_default, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt index 599e5497e5..c786320ac8 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/api/sum.kt @@ -21,14 +21,13 @@ import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateAll import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateFor import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOf import org.jetbrains.kotlinx.dataframe.impl.aggregation.modes.aggregateOfRow -import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveNumberColumns +import org.jetbrains.kotlinx.dataframe.impl.aggregation.primitiveOrMixedNumberColumns import org.jetbrains.kotlinx.dataframe.impl.columns.toNumberColumns -import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes +import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveOrMixedNumber import kotlin.experimental.ExperimentalTypeInference import kotlin.reflect.KClass import kotlin.reflect.KProperty import kotlin.reflect.KType -import kotlin.reflect.full.withNullability import kotlin.reflect.typeOf /* TODO KDocs @@ -70,7 +69,7 @@ public inline fun DataColumn.sumOf(crossinline expres // region DataRow -public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this, primitiveNumberColumns()) +public fun AnyRow.rowSum(): Number = Aggregators.sum.aggregateOfRow(this, primitiveOrMixedNumberColumns()) @JvmName("rowSumOfShort") public inline fun AnyRow.rowSumOf(_kClass: KClass = Short::class): Int = @@ -98,7 +97,7 @@ public inline fun AnyRow.rowSumOf(_kClass: KClass // unfortunately, we cannot make a `reified T : Number?` due to clashes public fun AnyRow.rowSumOf(type: KType): Number { - require(type.withNullability(false) in primitiveNumberTypes) { + require(type.isPrimitiveOrMixedNumber()) { "Type $type is not a primitive number type. Mean only supports primitive number types." } return Aggregators.sum.aggregateOfRow(this) { colsOf(type) } @@ -107,7 +106,7 @@ public fun AnyRow.rowSumOf(type: KType): Number { // region DataFrame -public fun DataFrame.sum(): DataRow = sumFor(primitiveNumberColumns()) +public fun DataFrame.sum(): DataRow = sumFor(primitiveOrMixedNumberColumns()) public fun DataFrame.sumFor(columns: ColumnsForAggregateSelector): DataRow = Aggregators.sum.aggregateFor(this, columns) @@ -185,7 +184,7 @@ public inline fun DataFrame.sumOf(crossinline express // region GroupBy @Refine @Interpretable("GroupBySum1") -public fun Grouped.sum(): DataFrame = sumFor(primitiveNumberColumns()) +public fun Grouped.sum(): DataFrame = sumFor(primitiveOrMixedNumberColumns()) @Refine @Interpretable("GroupBySum0") @@ -229,7 +228,7 @@ public inline fun Grouped.sumOf( // region Pivot -public fun Pivot.sum(separate: Boolean = false): DataRow = sumFor(separate, primitiveNumberColumns()) +public fun Pivot.sum(separate: Boolean = false): DataRow = sumFor(separate, primitiveOrMixedNumberColumns()) public fun Pivot.sumFor( separate: Boolean = false, @@ -266,7 +265,8 @@ public inline fun Pivot.sumOf(crossinline expression: // region PivotGroupBy -public fun PivotGroupBy.sum(separate: Boolean = false): DataFrame = sumFor(separate, primitiveNumberColumns()) +public fun PivotGroupBy.sum(separate: Boolean = false): DataFrame = + sumFor(separate, primitiveOrMixedNumberColumns()) public fun PivotGroupBy.sumFor( separate: Boolean = false, diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt index f3f42ced5b..36d087b872 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/NumberTypeUtils.kt @@ -40,8 +40,8 @@ private val unifiedNumberTypeGraphs = mutableMapOf?, second: KClass<*>, options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, -): KClass<*> = +): KClass<*>? = when { first == null -> second - first == second -> first - else -> getUnifiedNumberClassGraph(options).findNearestCommonVertex(first, second) - ?: error("Can not find common number type for $first and $second") } /** @@ -156,28 +153,28 @@ internal fun getUnifiedNumberClass( * * @param options See [UnifiedNumberTypeOptions] * @return The nearest common numeric type between the input types. - * If no common type is found, it returns [Number]. + * If no common type is found, it returns `null`. * @see UnifyingNumbers */ -internal fun Iterable.unifiedNumberType( +internal fun Iterable.unifiedNumberTypeOrNull( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, -): KType = +): KType? = fold(null as KType?) { a, b -> - getUnifiedNumberType(a, b, options) - } ?: typeOf() + getUnifiedNumberTypeOrNull(a, b, options) ?: return null + } -/** @include [unifiedNumberType] */ -internal fun Iterable>.unifiedNumberClass( +/** @include [unifiedNumberTypeOrNull] */ +internal fun Iterable>.unifiedNumberClassOrNull( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, -): KClass<*> = +): KClass<*>? = fold(null as KClass<*>?) { a, b -> - getUnifiedNumberClass(a, b, options) - } ?: Number::class + getUnifiedNumberClassOrNull(a, b, options) ?: return null + } /** * Converts the elements of the given iterable of numbers into a common numeric type based on complexity. * The common numeric type is determined using the provided [commonNumberType] parameter - * or calculated with [Iterable.unifiedNumberType] from the iterable's elements if not explicitly specified. + * or calculated with [Iterable.unifiedNumberTypeOrNull] from the iterable's elements if not explicitly specified. * * @param commonNumberType The desired common numeric type to convert the elements to. * By default, (or if `null`), this is determined using the types of the elements in the iterable. @@ -191,7 +188,12 @@ internal fun Iterable.convertToUnifiedNumberType( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, commonNumberType: KType? = null, ): Iterable { - val commonNumberType = commonNumberType ?: this.types().unifiedNumberType(options) + val commonNumberType = commonNumberType ?: this.types().let { types -> + types.unifiedNumberTypeOrNull(options) + ?: throw IllegalArgumentException( + "Cannot find unified number type of types: ${types.joinToString { renderType(it) }}", + ) + } val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? return map { if (it == null) return@map null @@ -216,7 +218,12 @@ internal fun Sequence.convertToUnifiedNumberType( options: UnifiedNumberTypeOptions = UnifiedNumberTypeOptions.DEFAULT, commonNumberType: KType? = null, ): Sequence { - val commonNumberType = commonNumberType ?: this.asIterable().types().unifiedNumberType(options) + val commonNumberType = commonNumberType ?: this.asIterable().types().let { types -> + types.unifiedNumberTypeOrNull(options) + ?: throw IllegalArgumentException( + "Cannot find unified number type of types: ${types.joinToString { renderType(it) }}", + ) + } val converter = createConverter(typeOf(), commonNumberType)!! as (Number) -> Number? return map { if (it == null) return@map null @@ -245,7 +252,28 @@ internal val primitiveNumberTypes: Set = typeOf(), ) -internal fun Any.isPrimitiveNumber(): Boolean = +/** Returns `true` only when this type is exactly `Number` or `Number?`. */ +@PublishedApi +internal fun KType.isMixedNumber(): Boolean = this == typeOf() || this == typeOf() + +/** + * Returns `true` when this type is one of the following (nullable) types: + * [Byte], [Short], [Int], [Long], [Float], or [Double]. + */ +@PublishedApi +internal fun KType.isPrimitiveNumber(): Boolean = this.withNullability(false) in primitiveNumberTypes + +/** + * Returns `true` when this type is one of the following (nullable) types: + * [Byte], [Short], [Int], [Long], [Float], [Double], or [Number]. + * + * Careful: Will return `true` for `Number`. + * This type may arise as a supertype from multiple non-primitive number types. + */ +@PublishedApi +internal fun KType.isPrimitiveOrMixedNumber(): Boolean = isPrimitiveNumber() || isMixedNumber() + +internal fun Number.isPrimitiveNumber(): Boolean = this is Byte || this is Short || this is Int || diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt index d6b4b96ed4..d8e24f6d02 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/TypeUtils.kt @@ -477,7 +477,7 @@ internal fun guessValueType( it.isSubclassOf(Number::class) && it != nothingClass } if (usedNumberClasses.isNotEmpty()) { - val unifiedNumberClass = usedNumberClasses.unifiedNumberClass() as KClass + val unifiedNumberClass = usedNumberClasses.unifiedNumberClassOrNull() as KClass classes -= usedNumberClasses classes += unifiedNumberClass } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt index 31784c6036..8cfc4e92a7 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/aggregators/TwoStepNumbersAggregator.kt @@ -12,7 +12,7 @@ import org.jetbrains.kotlinx.dataframe.impl.nothingType import org.jetbrains.kotlinx.dataframe.impl.primitiveNumberTypes import org.jetbrains.kotlinx.dataframe.impl.renderType import org.jetbrains.kotlinx.dataframe.impl.types -import org.jetbrains.kotlinx.dataframe.impl.unifiedNumberType +import org.jetbrains.kotlinx.dataframe.impl.unifiedNumberTypeOrNull import kotlin.reflect.KType import kotlin.reflect.full.isSubtypeOf import kotlin.reflect.full.starProjectedType @@ -95,11 +95,15 @@ internal class TwoStepNumbersAggregator( calculateReturnTypeOrNull(type = type.withNullability(false), emptyInput = colsEmpty) } if (typesAfterStepOne.anyNull()) return null - val commonType = (typesAfterStepOne as List) - .toSet() - .unifiedNumberType(PRIMITIVES_ONLY) - .withNullability(false) - return commonType + val typeSet = (typesAfterStepOne as List).toSet() + val unifiedType = typeSet.unifiedNumberTypeOrNull(PRIMITIVES_ONLY) + ?.withNullability(false) + ?: throw IllegalArgumentException( + "Cannot calculate the $name of the number types: ${typeSet.joinToString { renderType(it) }}. " + + "Note, only primitive number types are supported in statistics.", + ) + + return unifiedType } /** @@ -151,24 +155,28 @@ internal class TwoStepNumbersAggregator( @Suppress("UNCHECKED_CAST") override fun aggregateCalculatingType(values: Iterable, valueTypes: Set?): Return { val valueTypes = valueTypes?.takeUnless { it.isEmpty() } ?: values.types() - val commonType = valueTypes.unifiedNumberType(PRIMITIVES_ONLY) + val unifiedType = valueTypes.unifiedNumberTypeOrNull(PRIMITIVES_ONLY) + ?: throw IllegalArgumentException( + "Cannot calculate the $name of the number types: ${valueTypes.joinToString { renderType(it) }}. " + + "Note, only primitive number types are supported in statistics.", + ) - if (commonType.isSubtypeOf(typeOf()) && + if (unifiedType.isSubtypeOf(typeOf()) && (typeOf() in valueTypes || typeOf() in valueTypes) ) { logger.warn { "Number unification of Long -> Double happened during aggregation. Loss of precision may have occurred." } } - if (commonType.withNullability(false) !in primitiveNumberTypes && !commonType.isNothing) { + if (unifiedType.withNullability(false) !in primitiveNumberTypes && !unifiedType.isNothing) { throw IllegalArgumentException( - "Cannot calculate $name of ${renderType(commonType)}, only primitive numbers are supported.", + "Cannot calculate $name of ${renderType(unifiedType)}, only primitive numbers are supported.", ) } return super.aggregate( - values = values.convertToUnifiedNumberType(commonNumberType = commonType), - type = commonType, + values = values.convertToUnifiedNumberType(commonNumberType = unifiedType), + type = unifiedType, ) } diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt index 503efd92c5..136efa687e 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/impl/aggregation/getColumns.kt @@ -8,11 +8,10 @@ import org.jetbrains.kotlinx.dataframe.aggregation.NamedValue import org.jetbrains.kotlinx.dataframe.api.cast import org.jetbrains.kotlinx.dataframe.api.filter import org.jetbrains.kotlinx.dataframe.api.isNumber -import org.jetbrains.kotlinx.dataframe.api.isPrimitiveNumber +import org.jetbrains.kotlinx.dataframe.api.isPrimitiveOrMixedNumber import org.jetbrains.kotlinx.dataframe.api.valuesAreComparable import org.jetbrains.kotlinx.dataframe.columns.TypeSuggestion import org.jetbrains.kotlinx.dataframe.impl.columns.createColumnGuessingType -import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveNumber internal inline fun Aggregatable.remainingColumns( crossinline predicate: (AnyCol) -> Boolean, @@ -27,11 +26,11 @@ internal fun Aggregatable.numberColumns(): ColumnsSelector = remainingColumns { it.isNumber() } as ColumnsSelector @Suppress("UNCHECKED_CAST") -internal fun Aggregatable.primitiveNumberColumns(): ColumnsSelector = - remainingColumns { it.isPrimitiveNumber() } as ColumnsSelector +internal fun Aggregatable.primitiveOrMixedNumberColumns(): ColumnsSelector = + remainingColumns { it.isPrimitiveOrMixedNumber() } as ColumnsSelector -internal fun DataRow.primitiveNumberColumns(): ColumnsSelector = - { cols { it.isPrimitiveNumber() }.cast() } +internal fun DataRow.primitiveOrMixedNumberColumns(): ColumnsSelector = + { cols { it.isPrimitiveOrMixedNumber() }.cast() } internal fun NamedValue.toColumnWithPath() = path to createColumnGuessingType( diff --git a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt index 613c19ecb1..7ef56e1147 100644 --- a/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt +++ b/core/src/main/kotlin/org/jetbrains/kotlinx/dataframe/math/sum.kt @@ -50,9 +50,11 @@ internal val sumTypeConversion: CalculateReturnTypeOrNull = { type, _ -> typeOf(), typeOf() -> typeOf() // type remains the same - typeOf(), typeOf(), typeOf(), typeOf() -> type + typeOf(), typeOf(), typeOf(), typeOf(), typeOf() -> type - // defaults to Double - else -> typeOf() + nothingType -> typeOf() + + else -> + error("Unable to compute the sum for ${renderType(type)}, Only primitive numbers are supported.") } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt index 1980e5da70..0e158e73d4 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/statistics/sum.kt @@ -2,9 +2,10 @@ package org.jetbrains.kotlinx.dataframe.statistics import io.kotest.assertions.throwables.shouldThrow import io.kotest.matchers.shouldBe -import org.jetbrains.kotlinx.dataframe.DataColumn +import io.kotest.matchers.string.shouldContain import org.jetbrains.kotlinx.dataframe.api.columnOf import org.jetbrains.kotlinx.dataframe.api.dataFrameOf +import org.jetbrains.kotlinx.dataframe.api.isEmpty import org.jetbrains.kotlinx.dataframe.api.rowSum import org.jetbrains.kotlinx.dataframe.api.sum import org.jetbrains.kotlinx.dataframe.api.sumOf @@ -45,7 +46,7 @@ class SumTests { fun `test multiple columns`() { val value1 by columnOf(1, 2, 3) val value2 by columnOf(4.0, 5.0, 6.0) - val value3: DataColumn by columnOf(7.0, 8, null) + val value3 by columnOf(7.0, 8, null) val df = dataFrameOf(value1, value2, value3) val expected1 = 6 val expected2 = 15.0 @@ -88,8 +89,21 @@ class SumTests { @Test fun `unknown number type`() { + columnOf(1.toBigDecimal(), 2.toBigDecimal()).toDataFrame() + .sum() + .isEmpty() shouldBe true + } + + @Test + fun `mixed numbers`() { + // mixed number types are picked up implicitly + columnOf(1.0, 2).toDataFrame() + .sum()[0] shouldBe 3.0 + + // in the slight case a mixed number column contains unsupported numbers + // we give a helpful exception telling about primitive support only shouldThrow { - columnOf(1.toBigDecimal(), 2.toBigDecimal()).toDataFrame().sum() - } + columnOf(1.0, 2, 3.0.toBigDecimal()).toDataFrame().sum()[0] + }.message?.lowercase() shouldContain "primitive" } } diff --git a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt index 3518039dbf..3b40948732 100644 --- a/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt +++ b/core/src/test/kotlin/org/jetbrains/kotlinx/dataframe/types/UtilTests.kt @@ -11,7 +11,7 @@ import org.jetbrains.kotlinx.dataframe.impl.commonParents import org.jetbrains.kotlinx.dataframe.impl.commonType import org.jetbrains.kotlinx.dataframe.impl.commonTypeListifyValues import org.jetbrains.kotlinx.dataframe.impl.createType -import org.jetbrains.kotlinx.dataframe.impl.getUnifiedNumberClass +import org.jetbrains.kotlinx.dataframe.impl.getUnifiedNumberClassOrNull import org.jetbrains.kotlinx.dataframe.impl.guessValueType import org.jetbrains.kotlinx.dataframe.impl.isArray import org.jetbrains.kotlinx.dataframe.impl.isPrimitiveArray @@ -426,40 +426,40 @@ class UtilTests { @Test fun `common number types`() { // Same type - getUnifiedNumberClass(Int::class, Int::class) shouldBe Int::class - getUnifiedNumberClass(Double::class, Double::class) shouldBe Double::class + getUnifiedNumberClassOrNull(Int::class, Int::class) shouldBe Int::class + getUnifiedNumberClassOrNull(Double::class, Double::class) shouldBe Double::class // Direct parent-child relationships - getUnifiedNumberClass(Int::class, UShort::class) shouldBe Int::class - getUnifiedNumberClass(Long::class, UInt::class) shouldBe Long::class - getUnifiedNumberClass(Double::class, Float::class) shouldBe Double::class - getUnifiedNumberClass(UShort::class, Short::class) shouldBe Int::class - getUnifiedNumberClass(UByte::class, Byte::class) shouldBe Short::class + getUnifiedNumberClassOrNull(Int::class, UShort::class) shouldBe Int::class + getUnifiedNumberClassOrNull(Long::class, UInt::class) shouldBe Long::class + getUnifiedNumberClassOrNull(Double::class, Float::class) shouldBe Double::class + getUnifiedNumberClassOrNull(UShort::class, Short::class) shouldBe Int::class + getUnifiedNumberClassOrNull(UByte::class, Byte::class) shouldBe Short::class - getUnifiedNumberClass(UByte::class, UShort::class) shouldBe UShort::class + getUnifiedNumberClassOrNull(UByte::class, UShort::class) shouldBe UShort::class // Multi-level relationships - getUnifiedNumberClass(Byte::class, Int::class) shouldBe Int::class - getUnifiedNumberClass(UByte::class, Long::class) shouldBe Long::class - getUnifiedNumberClass(Short::class, Double::class) shouldBe Double::class - getUnifiedNumberClass(UInt::class, Int::class) shouldBe Long::class + getUnifiedNumberClassOrNull(Byte::class, Int::class) shouldBe Int::class + getUnifiedNumberClassOrNull(UByte::class, Long::class) shouldBe Long::class + getUnifiedNumberClassOrNull(Short::class, Double::class) shouldBe Double::class + getUnifiedNumberClassOrNull(UInt::class, Int::class) shouldBe Long::class // Top-level types - getUnifiedNumberClass(BigDecimal::class, Double::class) shouldBe BigDecimal::class - getUnifiedNumberClass(BigInteger::class, Long::class) shouldBe BigInteger::class - getUnifiedNumberClass(BigDecimal::class, BigInteger::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(BigDecimal::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(BigInteger::class, Long::class) shouldBe BigInteger::class + getUnifiedNumberClassOrNull(BigDecimal::class, BigInteger::class) shouldBe BigDecimal::class // Distant relationships - getUnifiedNumberClass(Byte::class, BigDecimal::class) shouldBe BigDecimal::class - getUnifiedNumberClass(UByte::class, Double::class) shouldBe Double::class + getUnifiedNumberClassOrNull(Byte::class, BigDecimal::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(UByte::class, Double::class) shouldBe Double::class // Complex type promotions - getUnifiedNumberClass(Int::class, Float::class) shouldBe Double::class - getUnifiedNumberClass(Long::class, Double::class) shouldBe BigDecimal::class - getUnifiedNumberClass(ULong::class, Double::class) shouldBe BigDecimal::class - getUnifiedNumberClass(BigInteger::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(Int::class, Float::class) shouldBe Double::class + getUnifiedNumberClassOrNull(Long::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(ULong::class, Double::class) shouldBe BigDecimal::class + getUnifiedNumberClassOrNull(BigInteger::class, Double::class) shouldBe BigDecimal::class // Edge case with null - getUnifiedNumberClass(null, Int::class) shouldBe Int::class + getUnifiedNumberClassOrNull(null, Int::class) shouldBe Int::class } }