diff --git a/pandas/compat/pyarrow.py b/pandas/compat/pyarrow.py
index cc5c7a2e51976..9bf7139769baa 100644
--- a/pandas/compat/pyarrow.py
+++ b/pandas/compat/pyarrow.py
@@ -11,8 +11,10 @@
     pa_version_under2p0 = _palv < Version("2.0.0")
     pa_version_under3p0 = _palv < Version("3.0.0")
     pa_version_under4p0 = _palv < Version("4.0.0")
+    pa_version_under5p0 = _palv < Version("5.0.0")
 except ImportError:
     pa_version_under1p0 = True
     pa_version_under2p0 = True
     pa_version_under3p0 = True
     pa_version_under4p0 = True
+    pa_version_under5p0 = True
diff --git a/pandas/tests/io/test_parquet.py b/pandas/tests/io/test_parquet.py
index 12a79f68d71c8..58aef2f2844df 100644
--- a/pandas/tests/io/test_parquet.py
+++ b/pandas/tests/io/test_parquet.py
@@ -17,6 +17,7 @@
 from pandas.compat.pyarrow import (
     pa_version_under1p0,
     pa_version_under2p0,
+    pa_version_under5p0,
 )
 import pandas.util._test_decorators as td
 
@@ -222,6 +223,29 @@ def compare(repeat):
         compare(repeat)
 
 
+def check_partition_names(path, expected):
+    """Check partitions of a parquet file are as expected.
+
+    Parameters
+    ----------
+    path: str
+        Path of the dataset.
+    expected: iterable of str
+        Expected partition names.
+    """
+    if pa_version_under5p0:
+        import pyarrow.parquet as pq
+
+        dataset = pq.ParquetDataset(path, validate_schema=False)
+        assert len(dataset.partitions.partition_names) == len(expected)
+        assert dataset.partitions.partition_names == set(expected)
+    else:
+        import pyarrow.dataset as ds
+
+        dataset = ds.dataset(path, partitioning="hive")
+        assert dataset.partitioning.schema.names == expected
+
+
 def test_invalid_engine(df_compat):
     msg = "engine must be one of 'pyarrow', 'fastparquet'"
     with pytest.raises(ValueError, match=msg):
@@ -743,11 +767,7 @@ def test_partition_cols_supported(self, pa, df_full):
         df = df_full
         with tm.ensure_clean_dir() as path:
             df.to_parquet(path, partition_cols=partition_cols, compression=None)
-            import pyarrow.parquet as pq
-
-            dataset = pq.ParquetDataset(path, validate_schema=False)
-            assert len(dataset.partitions.partition_names) == 2
-            assert dataset.partitions.partition_names == set(partition_cols)
+            check_partition_names(path, partition_cols)
             assert read_parquet(path).shape == df.shape
 
     def test_partition_cols_string(self, pa, df_full):
@@ -757,11 +777,7 @@ def test_partition_cols_string(self, pa, df_full):
         df = df_full
         with tm.ensure_clean_dir() as path:
             df.to_parquet(path, partition_cols=partition_cols, compression=None)
-            import pyarrow.parquet as pq
-
-            dataset = pq.ParquetDataset(path, validate_schema=False)
-            assert len(dataset.partitions.partition_names) == 1
-            assert dataset.partitions.partition_names == set(partition_cols_list)
+            check_partition_names(path, partition_cols_list)
             assert read_parquet(path).shape == df.shape
 
     @pytest.mark.parametrize("path_type", [str, pathlib.Path])
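
Note (not part of the patch): check_partition_names branches on pa_version_under5p0 because the new tests target two different pyarrow interfaces. The pre-5.0 branch keeps the legacy pq.ParquetDataset lookup and compares partition names as a set, matching the assertions it replaces; the 5.0+ branch switches to the pyarrow.dataset API, whose partitioning.schema.names preserves the order of partition_cols and so can be compared as a list. A minimal sketch of the 5.0+ branch, assuming pandas and pyarrow >= 5.0 are installed; the DataFrame below and tempfile.TemporaryDirectory are illustrative stand-ins for the df_full fixture and tm.ensure_clean_dir() used in the test suite:

import tempfile

import pandas as pd
import pyarrow.dataset as ds

df = pd.DataFrame({"bool": [True, False], "int": [1, 2], "value": [1.0, 2.0]})

with tempfile.TemporaryDirectory() as path:
    # Writes hive-style partition directories such as bool=true/int=1/...
    df.to_parquet(path, partition_cols=["bool", "int"], compression=None)

    # Rediscover the dataset, telling pyarrow to parse the hive-style paths.
    dataset = ds.dataset(path, partitioning="hive")

    # partitioning.schema.names preserves the order in which the partition
    # columns were written, so list equality is safe here (the legacy
    # ParquetDataset branch falls back to a set comparison instead).
    assert dataset.partitioning.schema.names == ["bool", "int"]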