@@ -782,9 +782,11 @@ def __init__(
782782        poll_endpoint : ApiEndpoint [EmptyRequest , R ],
783783        completed_statuses : list [str ],
784784        failed_statuses : list [str ],
785+         * ,
785786        status_extractor : Callable [[R ], Optional [str ]],
786787        progress_extractor : Callable [[R ], Optional [float ]] |  None  =  None ,
787788        result_url_extractor : Callable [[R ], Optional [str ]] |  None  =  None ,
789+         price_extractor : Callable [[R ], Optional [float ]] |  None  =  None ,
788790        request : Optional [T ] =  None ,
789791        api_base : str  |  None  =  None ,
790792        auth_token : Optional [str ] =  None ,
@@ -815,10 +817,12 @@ def __init__(
815817        self .status_extractor  =  status_extractor  or  (lambda  x : getattr (x , "status" , None ))
816818        self .progress_extractor  =  progress_extractor 
817819        self .result_url_extractor  =  result_url_extractor 
820+         self .price_extractor  =  price_extractor 
818821        self .node_id  =  node_id 
819822        self .completed_statuses  =  completed_statuses 
820823        self .failed_statuses  =  failed_statuses 
821824        self .final_response : Optional [R ] =  None 
825+         self .extracted_price : Optional [float ] =  None 
822826
823827    async  def  execute (self , client : Optional [ApiClient ] =  None ) ->  R :
824828        owns_client  =  client  is  None 
@@ -840,6 +844,8 @@ async def execute(self, client: Optional[ApiClient] = None) -> R:
840844    def  _display_text_on_node (self , text : str ):
841845        if  not  self .node_id :
842846            return 
847+         if  self .extracted_price  is  not None :
848+             text  =  f"Price: { self .extracted_price } \n { text }  
843849        PromptServer .instance .send_progress_text (text , self .node_id )
844850
845851    def  _display_time_progress_on_node (self , time_completed : int  |  float ):
@@ -877,9 +883,7 @@ async def _poll_until_complete(self, client: ApiClient) -> R:
877883            try :
878884                logging .debug ("[DEBUG] Polling attempt #%s" , poll_count )
879885
880-                 request_dict  =  (
881-                     None  if  self .request  is  None  else  self .request .model_dump (exclude_none = True )
882-                 )
886+                 request_dict  =  None  if  self .request  is  None  else  self .request .model_dump (exclude_none = True )
883887
884888                if  poll_count  ==  1 :
885889                    logging .debug (
@@ -912,6 +916,11 @@ async def _poll_until_complete(self, client: ApiClient) -> R:
912916                    if  new_progress  is  not None :
913917                        progress .update_absolute (new_progress , total = PROGRESS_BAR_MAX )
914918
919+                 if  self .price_extractor :
920+                     price  =  self .price_extractor (response_obj )
921+                     if  price  is  not None :
922+                         self .extracted_price  =  price 
923+ 
915924                if  status  ==  TaskStatus .COMPLETED :
916925                    message  =  "Task completed successfully" 
917926                    if  self .result_url_extractor :
0 commit comments