88
99from keras .src import tree
1010from keras .src .api_export import keras_export
11+ from keras .src .utils import file_utils
1112from keras .src .utils import io_utils
13+ from keras .src .utils .module_utils import grain
1214from keras .src .utils .module_utils import tensorflow as tf
1315
1416
@@ -299,6 +301,17 @@ def is_torch_dataset(dataset):
299301 return False
300302
301303
def is_grain_dataset(dataset):
    """Tell whether `dataset` derives from a Grain dataset class.

    The decision is made by inspecting class names and module paths along
    the method resolution order of `dataset.__class__`, rather than with an
    `isinstance` check against the Grain classes themselves.

    Args:
        dataset: Any object.

    Returns:
        `True` if some class in the MRO is named `MapDataset` or
        `IterDataset` and lives in a module whose path starts with
        `grain._src.python`; `False` otherwise.
    """
    if not hasattr(dataset, "__class__"):
        return False

    grain_class_names = ("MapDataset", "IterDataset")
    return any(
        klass.__name__ in grain_class_names
        and str(klass.__module__).startswith("grain._src.python")
        for klass in dataset.__class__.__mro__
    )
313+
314+
302315def _rescale_dataset_split_sizes (left_size , right_size , total_length ):
303316 """Rescale the dataset split sizes.
304317
@@ -476,6 +489,10 @@ def _get_type_spec(dataset):
476489 from torch .utils .data import Dataset as TorchDataset
477490
478491 return TorchDataset
492+ elif is_grain_dataset (dataset ):
493+ from grain import MapDataset
494+
495+ return MapDataset
479496 else :
480497 return None
481498
@@ -525,10 +542,17 @@ def index_directory(
525542 - class_names: names of the classes corresponding to these labels, in
526543 order.
527544 """
545+ if file_utils .is_remote_path (directory ):
546+ os_module = tf .io .gfile
547+ path_module = tf .io .gfile
548+ else :
549+ os_module = os
550+ path_module = os .path
551+
528552 if labels == "inferred" :
529553 subdirs = []
530- for subdir in sorted (tf . io . gfile .listdir (directory )):
531- if tf . io . gfile . isdir (tf . io . gfile .join (directory , subdir )):
554+ for subdir in sorted (os_module .listdir (directory )):
555+ if path_module . isdir (path_module .join (directory , subdir )):
532556 if not subdir .startswith ("." ):
533557 if subdir .endswith ("/" ):
534558 subdir = subdir [:- 1 ]
@@ -566,7 +590,7 @@ def index_directory(
566590 results = []
567591 filenames = []
568592
569- for dirpath in (tf . io . gfile .join (directory , subdir ) for subdir in subdirs ):
593+ for dirpath in (path_module .join (directory , subdir ) for subdir in subdirs ):
570594 results .append (
571595 pool .apply_async (
572596 index_subdirectory ,
@@ -608,7 +632,7 @@ def index_directory(
608632 )
609633 pool .close ()
610634 pool .join ()
611- file_paths = [tf . io . gfile .join (directory , fname ) for fname in filenames ]
635+ file_paths = [path_module .join (directory , fname ) for fname in filenames ]
612636
613637 if shuffle :
614638 # Shuffle globally to erase macro-structure
@@ -623,8 +647,10 @@ def index_directory(
623647
624648
625649def iter_valid_files (directory , follow_links , formats ):
650+ io_module = tf .io .gfile if file_utils .is_remote_path (directory ) else os
651+
626652 if not follow_links :
627- walk = tf . io . gfile .walk (directory )
653+ walk = io_module .walk (directory )
628654 else :
629655 walk = os .walk (directory , followlinks = follow_links )
630656 for root , _ , files in sorted (walk , key = lambda x : x [0 ]):
@@ -648,14 +674,18 @@ def index_subdirectory(directory, class_indices, follow_links, formats):
648674 paths, and `labels` is a list of integer labels corresponding
649675 to these files.
650676 """
677+ path_module = (
678+ tf .io .gfile if file_utils .is_remote_path (directory ) else os .path
679+ )
680+
651681 dirname = os .path .basename (directory )
652682 valid_files = iter_valid_files (directory , follow_links , formats )
653683 labels = []
654684 filenames = []
655685 for root , fname in valid_files :
656686 labels .append (class_indices [dirname ])
657- absolute_path = tf . io . gfile .join (root , fname )
658- relative_path = tf . io . gfile .join (
687+ absolute_path = path_module .join (root , fname )
688+ relative_path = path_module .join (
659689 dirname , os .path .relpath (absolute_path , directory )
660690 )
661691 filenames .append (relative_path )
@@ -700,7 +730,7 @@ def get_training_or_validation_split(samples, labels, validation_split, subset):
700730 return samples , labels
701731
702732
703- def labels_to_dataset (labels , label_mode , num_classes ):
733+ def labels_to_dataset_tf (labels , label_mode , num_classes ):
704734 """Create a `tf.data.Dataset` from the list/tuple of labels.
705735
706736 Args:
@@ -730,6 +760,51 @@ def labels_to_dataset(labels, label_mode, num_classes):
730760 return label_ds
731761
732762
def labels_to_dataset_grain(labels, label_mode, num_classes):
    """Build a `grain.MapDataset` that yields encoded labels.

    Args:
        labels: list/tuple of labels to wrap in a `grain.MapDataset`.
        label_mode: String selecting how each label is encoded. One of:
            - `"binary"`: the label is converted to a `float32` tensor and
              a trailing axis is appended
              (e.g. for `binary_crossentropy`).
            - `"categorical"`: the label is converted to `int32` and one-hot
              encoded over `num_classes`
              (e.g. for `categorical_crossentropy` loss).
            - `"int"`: the label is converted to an `int32` tensor.
        num_classes: number of classes of labels. Only used when
            `label_mode` is `"categorical"`.

    Returns:
        A `grain.MapDataset` instance.

    Raises:
        ValueError: If `label_mode` is not one of `"binary"`,
            `"categorical"` or `"int"`.
    """
    from keras.src import backend
    from keras.src import ops

    if label_mode not in ("binary", "categorical", "int"):
        raise ValueError(
            f"Invalid `label_mode`: {label_mode}. "
            "Expected one of: 'binary', 'categorical', 'int'."
        )

    def _encode_label(raw_label):
        # The tensor conversion runs inside an explicit CPU device scope.
        with backend.device_scope("cpu"):
            if label_mode == "binary":
                as_float = ops.convert_to_tensor(raw_label, dtype="float32")
                return ops.expand_dims(as_float, axis=-1)
            if label_mode == "categorical":
                as_index = ops.convert_to_tensor(raw_label, dtype="int32")
                return ops.one_hot(as_index, num_classes)
            return ops.convert_to_tensor(raw_label, dtype="int32")

    return grain.MapDataset.source(labels).map(_encode_label)
806+
807+
733808def check_validation_split_arg (validation_split , subset , shuffle , seed ):
734809 """Raise errors in case of invalid argument values.
735810
0 commit comments