diff --git a/pyiceberg/io/pyarrow.py b/pyiceberg/io/pyarrow.py index 04f30ec63e..71925c27cd 100644 --- a/pyiceberg/io/pyarrow.py +++ b/pyiceberg/io/pyarrow.py @@ -469,15 +469,18 @@ def __setstate__(self, state: Dict[str, Any]) -> None: self.fs_by_scheme = lru_cache(self._initialize_fs) -def schema_to_pyarrow(schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT) -> pa.schema: - return visit(schema, _ConvertToArrowSchema(metadata)) +def schema_to_pyarrow( + schema: Union[Schema, IcebergType], metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True +) -> pa.schema: + return visit(schema, _ConvertToArrowSchema(metadata, include_field_ids)) class _ConvertToArrowSchema(SchemaVisitorPerPrimitiveType[pa.DataType]): _metadata: Dict[bytes, bytes] - def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT) -> None: + def __init__(self, metadata: Dict[bytes, bytes] = EMPTY_DICT, include_field_ids: bool = True) -> None: self._metadata = metadata + self._include_field_ids = include_field_ids def schema(self, _: Schema, struct_result: pa.StructType) -> pa.schema: return pa.schema(list(struct_result), metadata=self._metadata) @@ -486,13 +489,17 @@ def struct(self, _: StructType, field_results: List[pa.DataType]) -> pa.DataType return pa.struct(field_results) def field(self, field: NestedField, field_result: pa.DataType) -> pa.Field: + metadata = {} + if field.doc: + metadata[PYARROW_FIELD_DOC_KEY] = field.doc + if self._include_field_ids: + metadata[PYARROW_PARQUET_FIELD_ID_KEY] = str(field.field_id) + return pa.field( name=field.name, type=field_result, nullable=field.optional, - metadata={PYARROW_FIELD_DOC_KEY: field.doc, PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)} - if field.doc - else {PYARROW_PARQUET_FIELD_ID_KEY: str(field.field_id)}, + metadata=metadata, ) def list(self, list_type: ListType, element_result: pa.DataType) -> pa.DataType: @@ -1130,7 +1137,7 @@ def project_table( tables = [f.result() for f in 
completed_futures if f.result()] if len(tables) < 1: - return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema)) + return pa.Table.from_batches([], schema=schema_to_pyarrow(projected_schema, include_field_ids=False)) result = pa.concat_tables(tables) @@ -1161,7 +1168,7 @@ def __init__(self, file_schema: Schema): def _cast_if_needed(self, field: NestedField, values: pa.Array) -> pa.Array: file_field = self.file_schema.find_field(field.field_id) if field.field_type.is_primitive and field.field_type != file_field.field_type: - return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type))) + return values.cast(schema_to_pyarrow(promote(file_field.field_type, field.field_type), include_field_ids=False)) return values def _construct_field(self, field: NestedField, arrow_type: pa.DataType) -> pa.Field: @@ -1188,7 +1195,7 @@ def struct( field_arrays.append(array) fields.append(self._construct_field(field, array.type)) elif field.optional: - arrow_type = schema_to_pyarrow(field.field_type) + arrow_type = schema_to_pyarrow(field.field_type, include_field_ids=False) field_arrays.append(pa.nulls(len(struct_array), type=arrow_type)) fields.append(self._construct_field(field, arrow_type)) else: diff --git a/tests/io/test_pyarrow.py b/tests/io/test_pyarrow.py index ec511f959d..baa9e30824 100644 --- a/tests/io/test_pyarrow.py +++ b/tests/io/test_pyarrow.py @@ -344,7 +344,7 @@ def test_deleting_hdfs_file_not_found() -> None: assert "Cannot delete file, does not exist:" in str(exc_info.value) -def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None: +def test_schema_to_pyarrow_schema_include_field_ids(table_schema_nested: Schema) -> None: actual = schema_to_pyarrow(table_schema_nested) expected = """foo: string -- field metadata -- @@ -402,6 +402,30 @@ def test_schema_to_pyarrow_schema(table_schema_nested: Schema) -> None: assert repr(actual) == expected +def 
test_schema_to_pyarrow_schema_exclude_field_ids(table_schema_nested: Schema) -> None: + actual = schema_to_pyarrow(table_schema_nested, include_field_ids=False) + expected = """foo: string +bar: int32 not null +baz: bool +qux: list<element: string not null> not null + child 0, element: string not null +quux: map<string, map<string, int32>> not null + child 0, entries: struct<key: string not null, value: map<string, int32> not null> not null + child 0, key: string not null + child 1, value: map<string, int32> not null + child 0, entries: struct<key: string not null, value: int32 not null> not null + child 0, key: string not null + child 1, value: int32 not null +location: list<element: struct<latitude: float, longitude: float> not null> not null + child 0, element: struct<latitude: float, longitude: float> not null + child 0, latitude: float + child 1, longitude: float +person: struct<name: string, age: int32 not null> + child 0, name: string + child 1, age: int32 not null""" + assert repr(actual) == expected + + def test_fixed_type_to_pyarrow() -> None: length = 22 iceberg_type = FixedType(length) @@ -945,23 +969,13 @@ def test_projection_add_column(file_int: str) -> None: == """id: int32 list: list<element: int32> child 0, element: int32 - -- field metadata -- - PARQUET:field_id: '21' map: map<int32, string> child 0, entries: struct<key: int32 not null, value: string> not null child 0, key: int32 not null - -- field metadata -- - PARQUET:field_id: '31' child 1, value: string - -- field metadata -- - PARQUET:field_id: '32' location: struct<lat: double, lon: double> child 0, lat: double - -- field metadata -- - PARQUET:field_id: '41' - child 1, lon: double - -- field metadata -- - PARQUET:field_id: '42'""" + child 1, lon: double""" ) @@ -1014,11 +1028,7 @@ def test_projection_add_column_struct(schema_int: Schema, file_int: str) -> None == """id: map<int32, string> child 0, entries: struct<key: int32 not null, value: string> not null child 0, key: int32 not null - -- field metadata -- - PARQUET:field_id: '3' - child 1, value: string - -- field metadata -- - PARQUET:field_id: '4'""" + child 1, value: string""" ) @@ -1062,12 +1072,7 @@ def test_projection_concat_files(schema_int: Schema, file_int: str) -> None: def test_projection_filter(schema_int: Schema, file_int: str) -> None: result_table = project(schema_int, [file_int], GreaterThan("id", 4)) assert len(result_table.columns[0]) ==
0 - assert ( - repr(result_table.schema) - == """id: int32 - -- field metadata -- - PARQUET:field_id: '1'""" - ) + assert repr(result_table.schema) == """id: int32""" def test_projection_filter_renamed_column(file_int: str) -> None: @@ -1304,11 +1309,7 @@ def test_projection_nested_struct_different_parent_id(file_struct: str) -> None: repr(result_table.schema) == """location: struct<lat: double, long: double> child 0, lat: double - -- field metadata -- - PARQUET:field_id: '41' - child 1, long: double - -- field metadata -- - PARQUET:field_id: '42'""" + child 1, long: double""" )