@@ -782,9 +782,11 @@ def __init__(
782782 poll_endpoint : ApiEndpoint [EmptyRequest , R ],
783783 completed_statuses : list [str ],
784784 failed_statuses : list [str ],
785+ * ,
785786 status_extractor : Callable [[R ], Optional [str ]],
786787 progress_extractor : Callable [[R ], Optional [float ]] | None = None ,
787788 result_url_extractor : Callable [[R ], Optional [str ]] | None = None ,
789+ price_extractor : Callable [[R ], Optional [float ]] | None = None ,
788790 request : Optional [T ] = None ,
789791 api_base : str | None = None ,
790792 auth_token : Optional [str ] = None ,
@@ -815,10 +817,12 @@ def __init__(
815817 self .status_extractor = status_extractor or (lambda x : getattr (x , "status" , None ))
816818 self .progress_extractor = progress_extractor
817819 self .result_url_extractor = result_url_extractor
820+ self .price_extractor = price_extractor
818821 self .node_id = node_id
819822 self .completed_statuses = completed_statuses
820823 self .failed_statuses = failed_statuses
821824 self .final_response : Optional [R ] = None
825+ self .extracted_price : Optional [float ] = None
822826
823827 async def execute (self , client : Optional [ApiClient ] = None ) -> R :
824828 owns_client = client is None
@@ -840,6 +844,8 @@ async def execute(self, client: Optional[ApiClient] = None) -> R:
840844 def _display_text_on_node (self , text : str ):
841845 if not self .node_id :
842846 return
847+ if self .extracted_price is not None :
848+ text = f"Price: { self .extracted_price } $\n { text } "
843849 PromptServer .instance .send_progress_text (text , self .node_id )
844850
845851 def _display_time_progress_on_node (self , time_completed : int | float ):
@@ -877,9 +883,7 @@ async def _poll_until_complete(self, client: ApiClient) -> R:
877883 try :
878884 logging .debug ("[DEBUG] Polling attempt #%s" , poll_count )
879885
880- request_dict = (
881- None if self .request is None else self .request .model_dump (exclude_none = True )
882- )
886+ request_dict = None if self .request is None else self .request .model_dump (exclude_none = True )
883887
884888 if poll_count == 1 :
885889 logging .debug (
@@ -912,6 +916,11 @@ async def _poll_until_complete(self, client: ApiClient) -> R:
912916 if new_progress is not None :
913917 progress .update_absolute (new_progress , total = PROGRESS_BAR_MAX )
914918
919+ if self .price_extractor :
920+ price = self .price_extractor (response_obj )
921+ if price is not None :
922+ self .extracted_price = price
923+
915924 if status == TaskStatus .COMPLETED :
916925 message = "Task completed successfully"
917926 if self .result_url_extractor :
0 commit comments