Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/datasets/features/image.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,6 +215,7 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
The Arrow types that can be converted to the Image pyarrow storage type are:

- `pa.string()` - it must contain the "path" data
- `pa.large_string()` - it must contain the "path" data (will be cast to string if possible)
- `pa.binary()` - it must contain the image bytes
- `pa.struct({"bytes": pa.binary()})`
- `pa.struct({"path": pa.string()})`
Expand All @@ -229,6 +230,15 @@ def cast_storage(self, storage: Union[pa.StringArray, pa.StructArray, pa.ListArr
`pa.StructArray`: Array in the Image arrow storage type, that is
`pa.struct({"bytes": pa.binary(), "path": pa.string()})`.
"""
if pa.types.is_large_string(storage.type):
try:
storage = storage.cast(pa.string())
except pa.ArrowInvalid as e:
raise ValueError(
f"Failed to cast large_string to string for Image feature. "
f"This can happen if string values exceed 2GB. "
f"Original error: {e}"
) from e
if pa.types.is_string(storage.type):
bytes_array = pa.array([None] * len(storage), type=pa.binary())
storage = pa.StructArray.from_arrays([bytes_array, storage], ["bytes", "path"], mask=storage.is_null())
Expand Down
12 changes: 12 additions & 0 deletions tests/features/test_image.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,6 +320,18 @@ def test_dataset_cast_to_image_features(shared_datadir, build_data):
assert isinstance(item["image"], PIL.Image.Image)


def test_dataset_cast_to_image_features_polars(shared_datadir):
import PIL.Image

pl = pytest.importorskip("polars")
image_path = str(shared_datadir / "test_image_rgb.jpg")
df = pl.DataFrame({"image_path": [image_path]})
dataset = Dataset.from_polars(df)
item = dataset.cast_column("image_path", Image())[0]
assert item.keys() == {"image_path"}
assert isinstance(item["image_path"], PIL.Image.Image)


@require_pil
def test_dataset_concatenate_image_features(shared_datadir):
# we use a different data structure between 1 and 2 to make sure they are compatible with each other
Expand Down
2 changes: 1 addition & 1 deletion tests/test_download_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ def test_download_manager_delete_extracted_files(xz_file):
assert extracted_path == dl_manager.extracted_paths[xz_file]
extracted_path = Path(extracted_path)
parts = extracted_path.parts
# import pdb; pdb.set_trace()

assert parts[-1] == hash_url_to_filename(str(xz_file), etag=None)
assert parts[-2] == extracted_subdir
assert extracted_path.exists()
Expand Down
Loading