diff --git a/spec/API_specification/dataframe_api/__init__.py b/spec/API_specification/dataframe_api/__init__.py index 2c6bc15b..5d9d3400 100644 --- a/spec/API_specification/dataframe_api/__init__.py +++ b/spec/API_specification/dataframe_api/__init__.py @@ -15,6 +15,7 @@ from .typing import DType, Scalar __all__ = [ + "Aggregation", "Bool", "Column", "DataFrame", diff --git a/spec/API_specification/dataframe_api/groupby_object.py b/spec/API_specification/dataframe_api/groupby_object.py index 062bb2d5..750f4383 100644 --- a/spec/API_specification/dataframe_api/groupby_object.py +++ b/spec/API_specification/dataframe_api/groupby_object.py @@ -6,7 +6,10 @@ from .dataframe_object import DataFrame -__all__ = ['GroupBy'] +__all__ = [ + "Aggregation", + "GroupBy", +] class GroupBy(Protocol): @@ -51,3 +54,75 @@ def var(self, *, correction: int | float = 1, skip_nulls: bool = True) -> DataFr def size(self) -> DataFrame: ... + + def aggregate(self, *aggregation: Aggregation) -> DataFrame: + """ + Aggregate columns according to given aggregation function. + + Examples + -------- + >>> df: DataFrame + >>> namespace = df.__dataframe_namespace__() + >>> df.group_by('year').aggregate( + ... namespace.Aggregation.sum('l_quantity').rename('sum_qty'), + ... namespace.Aggregation.mean('l_quantity').rename('avg_qty'), + ... namespace.Aggregation.mean('l_extended_price').rename('avg_price'), + ... namespace.Aggregation.mean('l_discount').rename('avg_disc'), + ... namespace.Aggregation.size().rename('count_order'), + ... ) + """ + ... + +class Aggregation(Protocol): + def rename(self, name: str) -> Aggregation: + """ + Assign given name to output of aggregation. + + If not called, the column's name will be used as the output name. + """ + ... + + @classmethod + def any(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def all(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def min(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def max(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def sum(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def prod(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def median(cls, column: str, *, skip_nulls: bool = True) -> Aggregation: + ... + + @classmethod + def mean(cls, column: str, *, skip_nulls: bool=True) -> Aggregation: + ... + + @classmethod + def std(cls, column: str, *, correction: int|float=1, skip_nulls: bool=True) -> Aggregation: + ... + + @classmethod + def var(cls, column: str, *, correction: int|float=1, skip_nulls: bool=True) -> Aggregation: + ... + + @classmethod + def size(cls) -> Aggregation: + ... + diff --git a/spec/API_specification/dataframe_api/typing.py b/spec/API_specification/dataframe_api/typing.py index d1c3b20a..5efeb1b1 100644 --- a/spec/API_specification/dataframe_api/typing.py +++ b/spec/API_specification/dataframe_api/typing.py @@ -15,7 +15,7 @@ from dataframe_api.column_object import Column from dataframe_api.dataframe_object import DataFrame -from dataframe_api.groupby_object import GroupBy +from dataframe_api.groupby_object import GroupBy, Aggregation as AggregationT if TYPE_CHECKING: from .dtypes import ( @@ -112,6 +112,8 @@ def __init__( class String(): ... + Aggregation: AggregationT + def concat(self, dataframes: Sequence[DataFrame]) -> DataFrame: ... @@ -146,7 +148,7 @@ def is_null(self, value: object, /) -> bool: def is_dtype(self, dtype: Any, kind: str | tuple[str, ...]) -> bool: ... - + def date(self, year: int, month: int, day: int) -> Scalar: ... @@ -164,6 +166,7 @@ def __column_consortium_standard__( __all__ = [ + "Aggregation", "Column", "DataFrame", "DType", diff --git a/spec/API_specification/examples/tpch/q1.py b/spec/API_specification/examples/tpch/q1.py new file mode 100644 index 00000000..b5c11287 --- /dev/null +++ b/spec/API_specification/examples/tpch/q1.py @@ -0,0 +1,37 @@ +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from dataframe_api.typing import SupportsDataFrameAPI + + +def query(lineitem_raw: SupportsDataFrameAPI) -> Any: + lineitem = lineitem_raw.__dataframe_consortium_standard__() + namespace = lineitem.__dataframe_namespace__() + + mask = lineitem.get_column_by_name("l_shipdate") <= namespace.date(1998, 9, 2) + lineitem = lineitem.assign( + ( + lineitem.get_column_by_name("l_extended_price") + * (1 - lineitem.get_column_by_name("l_discount")) + ).rename("l_disc_price"), + ( + lineitem.get_column_by_name("l_extended_price") + * (1 - lineitem.get_column_by_name("l_discount")) + * (1 + lineitem.get_column_by_name("l_tax")) + ).rename("l_charge"), + ) + result = ( + lineitem.filter(mask) + .group_by("l_returnflag", "l_linestatus") + .aggregate( + namespace.Aggregation.sum("l_quantity").rename("sum_qty"), + namespace.Aggregation.sum("l_extendedprice").rename("sum_base_price"), + namespace.Aggregation.sum("l_disc_price").rename("sum_disc_price"), + namespace.Aggregation.sum("change").rename("sum_charge"), + namespace.Aggregation.mean("l_quantity").rename("avg_qty"), + namespace.Aggregation.mean("l_discount").rename("avg_disc"), + namespace.Aggregation.size().rename("count_order"), + ) + .sort("l_returnflag", "l_linestatus") + ) + return result.dataframe diff --git a/spec/API_specification/examples/tpch/q5.py b/spec/API_specification/examples/tpch/q5.py index 9109bcd5..332967c7 100644 --- a/spec/API_specification/examples/tpch/q5.py +++ b/spec/API_specification/examples/tpch/q5.py @@ -68,7 +68,6 @@ def query( * (1 - result.get_column_by_name("l_discount")) ).rename("revenue") result = result.assign(new_column) - result = result.select("revenue", "n_name") - result = result.group_by("n_name").sum() + result = result.group_by("n_name").aggregate(namespace.Aggregation.sum("revenue")) return result.dataframe diff --git a/spec/conf.py b/spec/conf.py index 9b9a405a..cc6e3270 100644 --- a/spec/conf.py +++ b/spec/conf.py @@ -84,6 +84,7 @@ ('py:class', 'Scalar'), ('py:class', 'Bool'), ('py:class', 'optional'), + ('py:class', 'Aggregation'), ('py:class', 'NullType'), ('py:class', 'Namespace'), ('py:class', 'SupportsDataFrameAPI'),