@@ -776,8 +776,10 @@ def _download_from_youtube(path):
776776 if accelerator .is_main_process :
777777 force_download = dataset_kwargs .get ("force_download" , False )
778778 force_unzip = dataset_kwargs .get ("force_unzip" , False )
779+ print (force_download )
779780 cache_path = snapshot_download (repo_id = self .DATASET_PATH , repo_type = "dataset" , force_download = force_download , etag_timeout = 60 )
780781 zip_files = glob (os .path .join (cache_path , "**/*.zip" ), recursive = True )
782+ tar_files = glob (os .path .join (cache_path , "**/*.tar*" ), recursive = True )
781783
782784 def unzip_video_data (zip_file ):
783785 import zipfile
@@ -786,10 +788,57 @@ def unzip_video_data(zip_file):
786788 zip_ref .extractall (cache_dir )
787789 eval_logger .info (f"Extracted all files from { zip_file } to { cache_dir } " )
788790
791+ def untar_video_data (tar_file ):
792+ import tarfile
793+ with tarfile .open (tar_file , "r" ) as tar_ref :
794+ tar_ref .extractall (cache_dir )
795+ eval_logger .info (f"Extracted all files from { tar_file } to { cache_dir } " )
796+
797+
798+
799+ def concat_tar_parts (tar_parts , output_tar ):
800+ print ("This is the output file:" , output_tar , "from:" , tar_parts )
801+ try :
802+ with open (output_tar , 'wb' ) as out_tar :
803+ from tqdm import tqdm
804+ for part in tqdm (sorted (tar_parts )):
805+ with open (part , 'rb' ) as part_file :
806+ out_tar .write (part_file .read ())
807+ except Exception as ex :
808+ print ("Error!!!" , ex )
809+ eval_logger .info (f"Concatenated parts { tar_parts } into { output_tar } " )
810+
811+ # Unzip zip files if needed
789812 if force_unzip or (not os .path .exists (cache_dir ) and len (zip_files ) > 0 ):
790813 for zip_file in zip_files :
791814 unzip_video_data (zip_file )
792815
816+ # Concatenate and extract tar files if needed
817+ if force_unzip or (not os .path .exists (cache_dir ) and len (tar_files ) > 0 ):
818+ tar_parts_dict = {}
819+
820+ # Group tar parts together
821+ for tar_file in tar_files :
822+ base_name = tar_file .split ('.tar' )[0 ]
823+ if base_name not in tar_parts_dict :
824+ tar_parts_dict [base_name ] = []
825+ tar_parts_dict [base_name ].append (tar_file )
826+
827+ print (tar_parts_dict )
828+
829+ # Concatenate and untar split parts
830+ for base_name , parts in tar_parts_dict .items ():
831+ eval_logger .info (f"Extracting following tar files: { parts } " )
832+ output_tar = base_name + ".tar"
833+ if not os .path .exists (output_tar ):
834+ eval_logger .info (f"Start concatenating tar files" )
835+
836+ concat_tar_parts (parts , output_tar )
837+ eval_logger .info (f"Finish concatenating tar files" )
838+
839+ if not os .path .exists (os .path .join (cache_dir , os .path .basename (base_name ))):
840+ untar_video_data (output_tar )
841+
793842 accelerator .wait_for_everyone ()
794843 dataset_kwargs .pop ("cache_dir" )
795844 dataset_kwargs .pop ("video" )
0 commit comments