diff --git a/protocol/pandas_implementation.py b/protocol/pandas_implementation.py index a016d25c..b3d8943e 100644 --- a/protocol/pandas_implementation.py +++ b/protocol/pandas_implementation.py @@ -97,6 +97,20 @@ class _DtypeKind(enum.IntEnum): DATETIME = 22 CATEGORICAL = 23 +class _Device(enum.IntEnum): + CPU = 1 + CUDA = 2 + CPU_PINNED = 3 + OPENCL = 4 + VULKAN = 7 + METAL = 8 + VPI = 9 + ROCM = 10 + +_INTS = {8: np.int8, 16: np.int16, 32: np.int32, 64: np.int64} +_UNITS = {8: np.uint8, 16: np.uint16, 32: np.uint32, 64: np.uint64} +_FLOATS = {32: np.float32, 64: np.float64} +_NP_DTYPES = {0: _INTS, 1: _UNITS, 2: _FLOATS, 20: {8: bool}} def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray: """ @@ -108,24 +122,17 @@ def convert_column_to_ndarray(col : ColumnObject) -> np.ndarray: if col.describe_null[0] not in (0, 1): raise NotImplementedError("Null values represented as masks or " "sentinel values not handled yet") - - _buffer, _dtype = col.get_buffers()["data"] - return buffer_to_ndarray(_buffer, _dtype), _buffer + buffers = col.get_buffers() + _buffer, _dtype = buffers["data"] + # there is a strange side effect (failing unit test) when replacing below + # `buffers` by `col.get_buffers()`. It is like the buffer has changed between + # the `buffer_to_ndarray` call and `col.get_buffers()` + return buffer_to_ndarray(_buffer, _dtype), buffers def buffer_to_ndarray(_buffer, _dtype) -> np.ndarray: - # Handle the dtype - kind = _dtype[0] bitwidth = _dtype[1] - _k = _DtypeKind - if _dtype[0] not in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL): - raise RuntimeError("Not a boolean, integer or floating-point dtype") - - _ints = {8: np.int8, 16: np.int16, 32: np.int32, 64: np.int64} - _uints = {8: np.uint8, 16: np.uint16, 32: np.uint32, 64: np.uint64} - _floats = {32: np.float32, 64: np.float64} - _np_dtypes = {0: _ints, 1: _uints, 2: _floats, 20: {8: bool}} - column_dtype = _np_dtypes[kind][bitwidth] + column_dtype = protocol_dtype_to_np_dtype(_dtype) # No DLPack yet, so need to construct a new ndarray from the data pointer # and size in the buffer plus the dtype on the column @@ -140,6 +147,14 @@ def buffer_to_ndarray(_buffer, _dtype) -> np.ndarray: return x +def protocol_dtype_to_np_dtype(_dtype): + kind = _dtype[0] + bitwidth = _dtype[1] + _k = _DtypeKind + if _dtype[0] not in (_k.INT, _k.UINT, _k.FLOAT, _k.BOOL): + raise RuntimeError("Not a boolean, integer or floating-point dtype") + + return _NP_DTYPES[kind][bitwidth] def convert_categorical_column(col : ColumnObject) -> pd.Series: """ @@ -153,7 +168,8 @@ def convert_categorical_column(col : ColumnObject) -> pd.Series: # categories = col._col.values.categories.values # codes = col._col.values.codes categories = np.asarray(list(mapping.values())) - codes_buffer, codes_dtype = col.get_buffers()["data"] + buffers = col.get_buffers() + codes_buffer, codes_dtype = buffers["data"] codes = buffer_to_ndarray(codes_buffer, codes_dtype) values = categories[codes] @@ -169,7 +185,7 @@ def convert_categorical_column(col : ColumnObject) -> pd.Series: raise NotImplementedError("Only categorical columns with sentinel " "value supported at the moment") - return series, codes_buffer + return series, buffers def convert_string_column(col : ColumnObject) -> np.ndarray: @@ -309,10 +325,7 @@ def __dlpack_device__(self) -> Tuple[enum.IntEnum, int]: """ Device type and device ID for where the data in the buffer resides. """ - class Device(enum.IntEnum): - CPU = 1 - - return (Device.CPU, None) + return (_Device.CPU, None) def __repr__(self) -> str: return 'PandasBuffer(' + str({'bufsize': self.bufsize,