22import json
33import jsonlines
44import os
5+ from args import args
56from pathlib import Path
67from PIL import Image
78from utils import get_datasets
89
910
10- # TODO: pass gs_url as a command line flag
11- # see https://cloud.google.com/docs/authentication/provide-credentials-adc to authorize
12- gs_url = "gs://shark-datasets/portraits"
13-
1411shark_root = Path (__file__ ).parent .parent
1512demo_css = shark_root .joinpath ("web/demo.css" ).resolve ()
1613nodlogo_loc = shark_root .joinpath (
3027 elem_id = "top_logo" ,
3128 ).style (width = 150 , height = 100 )
3229
33- datasets , images = get_datasets (gs_url )
30+ datasets , images , ds_w_prompts = get_datasets (args . gs_url )
3431 prompt_data = dict ()
3532
3633 with gr .Row (elem_id = "ui_body" ):
37- # TODO: add multiselect dataset
34+ # TODO: add multiselect dataset, there is a gradio version conflict
3835 dataset = gr .Dropdown (label = "Dataset" , choices = datasets )
3936 image_name = gr .Dropdown (label = "Image" , choices = [])
4037
41- with gr .Row (elem_id = "ui_body" , visible = True ):
38+ with gr .Row (elem_id = "ui_body" ):
4239 # TODO: add ability to search image by typing
4340 with gr .Column (scale = 1 , min_width = 600 ):
4441 image = gr .Image (type = "filepath" ).style (height = 512 )
6158 finish = gr .Button ("Finish" )
6259
6360 def filter_datasets (dataset ):
64- # TODO: execute finish process when switching dataset
6561 if dataset is None :
6662 return gr .Dropdown .update (value = None , choices = [])
6763
6864 # create the dataset dir if doesn't exist and download prompt file
6965 dataset_path = str (shark_root ) + "/dataset/" + dataset
70- # TODO: check if metadata.jsonl exists
71- prompt_gs_path = gs_url + "/" + dataset + "/metadata.jsonl"
7266 if not os .path .exists (dataset_path ):
7367 os .mkdir (dataset_path )
74- os .system (f'gsutil cp "{ prompt_gs_path } " "{ dataset_path } "/' )
7568
7669 # read prompt jsonlines file
7770 prompt_data .clear ()
78- with jsonlines .open (dataset_path + "/metadata.jsonl" ) as reader :
79- for line in reader .iter (type = dict , skip_invalid = True ):
80- prompt_data [line ["file_name" ]] = (
81- [line ["text" ]]
82- if type (line ["text" ]) is str
83- else line ["text" ]
84- )
71+ if dataset in ds_w_prompts :
72+ prompt_gs_path = args .gs_url + "/" + dataset + "/metadata.jsonl"
73+ os .system (f'gsutil cp "{ prompt_gs_path } " "{ dataset_path } "/' )
74+ with jsonlines .open (dataset_path + "/metadata.jsonl" ) as reader :
75+ for line in reader .iter (type = dict , skip_invalid = True ):
76+ prompt_data [line ["file_name" ]] = (
77+ [line ["text" ]]
78+ if type (line ["text" ]) is str
79+ else line ["text" ]
80+ )
8581
8682 return gr .Dropdown .update (choices = images [dataset ])
8783
@@ -92,8 +88,7 @@ def display_image(dataset, image_name):
9288 return gr .Image .update (value = None ), gr .Dropdown .update (value = None )
9389
9490 # download and load the image
95- # TODO: remove previous image if change image from dropdown
96- img_gs_path = gs_url + "/" + dataset + "/" + image_name
91+ img_gs_path = args .gs_url + "/" + dataset + "/" + image_name
9792 img_sub_path = "/" .join (image_name .split ("/" )[:- 1 ])
9893 img_dst_path = (
9994 str (shark_root ) + "/dataset/" + dataset + "/" + img_sub_path + "/"
@@ -103,6 +98,8 @@ def display_image(dataset, image_name):
10398 os .system (f'gsutil cp "{ img_gs_path } " "{ img_dst_path } "' )
10499 img = Image .open (img_dst_path + image_name .split ("/" )[- 1 ])
105100
101+ if image_name not in prompt_data .keys ():
102+ prompt_data [image_name ] = []
106103 prompt_choices = ["Add new" ]
107104 prompt_choices += prompt_data [image_name ]
108105 return gr .Image .update (value = img ), gr .Dropdown .update (
@@ -144,6 +141,8 @@ def save_prompt(dataset, image_name, prompts, prompt):
144141 # write prompt jsonlines file
145142 with open (prompt_path , "w" ) as f :
146143 for key , value in prompt_data .items ():
144+ if not value :
145+ continue
147146 v = value if len (value ) > 1 else value [0 ]
148147 f .write (json .dumps ({"file_name" : key , "text" : v }))
149148 f .write ("\n " )
@@ -171,6 +170,8 @@ def delete_prompt(dataset, image_name, prompts):
171170 # write prompt jsonlines file
172171 with open (prompt_path , "w" ) as f :
173172 for key , value in prompt_data .items ():
173+ if not value :
174+ continue
174175 v = value if len (value ) > 1 else value [0 ]
175176 f .write (json .dumps ({"file_name" : key , "text" : v }))
176177 f .write ("\n " )
@@ -227,7 +228,7 @@ def finish_annotation(dataset):
227228
228229 # upload prompt and remove local data
229230 dataset_path = str (shark_root ) + "/dataset/" + dataset
230- dataset_gs_path = gs_url + "/" + dataset + "/"
231+ dataset_gs_path = args . gs_url + "/" + dataset + "/"
231232 os .system (
232233 f'gsutil cp "{ dataset_path } /metadata.jsonl" "{ dataset_gs_path } "'
233234 )
@@ -240,8 +241,8 @@ def finish_annotation(dataset):
240241
241242if __name__ == "__main__" :
242243 shark_web .launch (
243- share = False ,
244+ share = args . share ,
244245 inbrowser = True ,
245246 server_name = "0.0.0.0" ,
246- server_port = 8080 ,
247+ server_port = args . server_port ,
247248 )
0 commit comments