Skip to content

Commit 20182a3

Browse files
authored
convert StabilityAI to use new API client (#10582)
1 parent 5f109fe commit 20182a3

File tree

2 files changed

+49
-136
lines changed

2 files changed

+49
-136
lines changed

comfy_api_nodes/nodes_stability.py

Lines changed: 47 additions & 134 deletions
Original file line numberDiff line numberDiff line change
@@ -20,20 +20,16 @@
2020
StabilityAudioInpaintRequest,
2121
StabilityAudioResponse,
2222
)
23-
from comfy_api_nodes.apis.client import (
24-
ApiEndpoint,
25-
HttpMethod,
26-
SynchronousOperation,
27-
PollingOperation,
28-
EmptyRequest,
29-
)
3023
from comfy_api_nodes.util import (
3124
validate_audio_duration,
3225
validate_string,
3326
audio_input_to_mp3,
3427
bytesio_to_image_tensor,
3528
tensor_to_bytesio,
3629
audio_bytes_to_audio_input,
30+
sync_op,
31+
poll_op,
32+
ApiEndpoint,
3733
)
3834

3935
import torch
@@ -161,19 +157,11 @@ async def execute(
161157
"image": image_binary
162158
}
163159

164-
auth = {
165-
"auth_token": cls.hidden.auth_token_comfy_org,
166-
"comfy_api_key": cls.hidden.api_key_comfy_org,
167-
}
168-
169-
operation = SynchronousOperation(
170-
endpoint=ApiEndpoint(
171-
path="/proxy/stability/v2beta/stable-image/generate/ultra",
172-
method=HttpMethod.POST,
173-
request_model=StabilityStableUltraRequest,
174-
response_model=StabilityStableUltraResponse,
175-
),
176-
request=StabilityStableUltraRequest(
160+
response_api = await sync_op(
161+
cls,
162+
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/ultra", method="POST"),
163+
response_model=StabilityStableUltraResponse,
164+
data=StabilityStableUltraRequest(
177165
prompt=prompt,
178166
negative_prompt=negative_prompt,
179167
aspect_ratio=aspect_ratio,
@@ -183,9 +171,7 @@ async def execute(
183171
),
184172
files=files,
185173
content_type="multipart/form-data",
186-
auth_kwargs=auth,
187174
)
188-
response_api = await operation.execute()
189175

190176
if response_api.finish_reason != "SUCCESS":
191177
raise Exception(f"Stable Image Ultra generation failed: {response_api.finish_reason}.")
@@ -313,19 +299,11 @@ async def execute(
313299
"image": image_binary
314300
}
315301

316-
auth = {
317-
"auth_token": cls.hidden.auth_token_comfy_org,
318-
"comfy_api_key": cls.hidden.api_key_comfy_org,
319-
}
320-
321-
operation = SynchronousOperation(
322-
endpoint=ApiEndpoint(
323-
path="/proxy/stability/v2beta/stable-image/generate/sd3",
324-
method=HttpMethod.POST,
325-
request_model=StabilityStable3_5Request,
326-
response_model=StabilityStableUltraResponse,
327-
),
328-
request=StabilityStable3_5Request(
302+
response_api = await sync_op(
303+
cls,
304+
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/generate/sd3", method="POST"),
305+
response_model=StabilityStableUltraResponse,
306+
data=StabilityStable3_5Request(
329307
prompt=prompt,
330308
negative_prompt=negative_prompt,
331309
aspect_ratio=aspect_ratio,
@@ -338,9 +316,7 @@ async def execute(
338316
),
339317
files=files,
340318
content_type="multipart/form-data",
341-
auth_kwargs=auth,
342319
)
343-
response_api = await operation.execute()
344320

345321
if response_api.finish_reason != "SUCCESS":
346322
raise Exception(f"Stable Diffusion 3.5 Image generation failed: {response_api.finish_reason}.")
@@ -427,29 +403,19 @@ async def execute(
427403
"image": image_binary
428404
}
429405

430-
auth = {
431-
"auth_token": cls.hidden.auth_token_comfy_org,
432-
"comfy_api_key": cls.hidden.api_key_comfy_org,
433-
}
434-
435-
operation = SynchronousOperation(
436-
endpoint=ApiEndpoint(
437-
path="/proxy/stability/v2beta/stable-image/upscale/conservative",
438-
method=HttpMethod.POST,
439-
request_model=StabilityUpscaleConservativeRequest,
440-
response_model=StabilityStableUltraResponse,
441-
),
442-
request=StabilityUpscaleConservativeRequest(
406+
response_api = await sync_op(
407+
cls,
408+
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/conservative", method="POST"),
409+
response_model=StabilityStableUltraResponse,
410+
data=StabilityUpscaleConservativeRequest(
443411
prompt=prompt,
444412
negative_prompt=negative_prompt,
445413
creativity=round(creativity,2),
446414
seed=seed,
447415
),
448416
files=files,
449417
content_type="multipart/form-data",
450-
auth_kwargs=auth,
451418
)
452-
response_api = await operation.execute()
453419

454420
if response_api.finish_reason != "SUCCESS":
455421
raise Exception(f"Stability Upscale Conservative generation failed: {response_api.finish_reason}.")
@@ -544,19 +510,11 @@ async def execute(
544510
"image": image_binary
545511
}
546512

547-
auth = {
548-
"auth_token": cls.hidden.auth_token_comfy_org,
549-
"comfy_api_key": cls.hidden.api_key_comfy_org,
550-
}
551-
552-
operation = SynchronousOperation(
553-
endpoint=ApiEndpoint(
554-
path="/proxy/stability/v2beta/stable-image/upscale/creative",
555-
method=HttpMethod.POST,
556-
request_model=StabilityUpscaleCreativeRequest,
557-
response_model=StabilityAsyncResponse,
558-
),
559-
request=StabilityUpscaleCreativeRequest(
513+
response_api = await sync_op(
514+
cls,
515+
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/creative", method="POST"),
516+
response_model=StabilityAsyncResponse,
517+
data=StabilityUpscaleCreativeRequest(
560518
prompt=prompt,
561519
negative_prompt=negative_prompt,
562520
creativity=round(creativity,2),
@@ -565,25 +523,15 @@ async def execute(
565523
),
566524
files=files,
567525
content_type="multipart/form-data",
568-
auth_kwargs=auth,
569526
)
570-
response_api = await operation.execute()
571-
572-
operation = PollingOperation(
573-
poll_endpoint=ApiEndpoint(
574-
path=f"/proxy/stability/v2beta/results/{response_api.id}",
575-
method=HttpMethod.GET,
576-
request_model=EmptyRequest,
577-
response_model=StabilityResultsGetResponse,
578-
),
527+
528+
response_poll = await poll_op(
529+
cls,
530+
ApiEndpoint(path=f"/proxy/stability/v2beta/results/{response_api.id}"),
531+
response_model=StabilityResultsGetResponse,
579532
poll_interval=3,
580-
completed_statuses=[StabilityPollStatus.finished],
581-
failed_statuses=[StabilityPollStatus.failed],
582533
status_extractor=lambda x: get_async_dummy_status(x),
583-
auth_kwargs=auth,
584-
node_id=cls.hidden.unique_id,
585534
)
586-
response_poll: StabilityResultsGetResponse = await operation.execute()
587535

588536
if response_poll.finish_reason != "SUCCESS":
589537
raise Exception(f"Stability Upscale Creative generation failed: {response_poll.finish_reason}.")
@@ -628,24 +576,13 @@ async def execute(cls, image: torch.Tensor) -> IO.NodeOutput:
628576
"image": image_binary
629577
}
630578

631-
auth = {
632-
"auth_token": cls.hidden.auth_token_comfy_org,
633-
"comfy_api_key": cls.hidden.api_key_comfy_org,
634-
}
635-
636-
operation = SynchronousOperation(
637-
endpoint=ApiEndpoint(
638-
path="/proxy/stability/v2beta/stable-image/upscale/fast",
639-
method=HttpMethod.POST,
640-
request_model=EmptyRequest,
641-
response_model=StabilityStableUltraResponse,
642-
),
643-
request=EmptyRequest(),
579+
response_api = await sync_op(
580+
cls,
581+
ApiEndpoint(path="/proxy/stability/v2beta/stable-image/upscale/fast", method="POST"),
582+
response_model=StabilityStableUltraResponse,
644583
files=files,
645584
content_type="multipart/form-data",
646-
auth_kwargs=auth,
647585
)
648-
response_api = await operation.execute()
649586

650587
if response_api.finish_reason != "SUCCESS":
651588
raise Exception(f"Stability Upscale Fast failed: {response_api.finish_reason}.")
@@ -717,21 +654,13 @@ def define_schema(cls):
717654
async def execute(cls, model: str, prompt: str, duration: int, seed: int, steps: int) -> IO.NodeOutput:
718655
validate_string(prompt, max_length=10000)
719656
payload = StabilityTextToAudioRequest(prompt=prompt, model=model, duration=duration, seed=seed, steps=steps)
720-
operation = SynchronousOperation(
721-
endpoint=ApiEndpoint(
722-
path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio",
723-
method=HttpMethod.POST,
724-
request_model=StabilityTextToAudioRequest,
725-
response_model=StabilityAudioResponse,
726-
),
727-
request=payload,
657+
response_api = await sync_op(
658+
cls,
659+
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/text-to-audio", method="POST"),
660+
response_model=StabilityAudioResponse,
661+
data=payload,
728662
content_type="multipart/form-data",
729-
auth_kwargs= {
730-
"auth_token": cls.hidden.auth_token_comfy_org,
731-
"comfy_api_key": cls.hidden.api_key_comfy_org,
732-
},
733663
)
734-
response_api = await operation.execute()
735664
if not response_api.audio:
736665
raise ValueError("No audio file was received in response.")
737666
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -814,22 +743,14 @@ async def execute(
814743
payload = StabilityAudioToAudioRequest(
815744
prompt=prompt, model=model, duration=duration, seed=seed, steps=steps, strength=strength
816745
)
817-
operation = SynchronousOperation(
818-
endpoint=ApiEndpoint(
819-
path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio",
820-
method=HttpMethod.POST,
821-
request_model=StabilityAudioToAudioRequest,
822-
response_model=StabilityAudioResponse,
823-
),
824-
request=payload,
746+
response_api = await sync_op(
747+
cls,
748+
ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/audio-to-audio", method="POST"),
749+
response_model=StabilityAudioResponse,
750+
data=payload,
825751
content_type="multipart/form-data",
826752
files={"audio": audio_input_to_mp3(audio)},
827-
auth_kwargs= {
828-
"auth_token": cls.hidden.auth_token_comfy_org,
829-
"comfy_api_key": cls.hidden.api_key_comfy_org,
830-
},
831753
)
832-
response_api = await operation.execute()
833754
if not response_api.audio:
834755
raise ValueError("No audio file was received in response.")
835756
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))
@@ -935,22 +856,14 @@ async def execute(
935856
mask_start=mask_start,
936857
mask_end=mask_end,
937858
)
938-
operation = SynchronousOperation(
939-
endpoint=ApiEndpoint(
940-
path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint",
941-
method=HttpMethod.POST,
942-
request_model=StabilityAudioInpaintRequest,
943-
response_model=StabilityAudioResponse,
944-
),
945-
request=payload,
859+
response_api = await sync_op(
860+
cls,
861+
endpoint=ApiEndpoint(path="/proxy/stability/v2beta/audio/stable-audio-2/inpaint", method="POST"),
862+
response_model=StabilityAudioResponse,
863+
data=payload,
946864
content_type="multipart/form-data",
947865
files={"audio": audio_input_to_mp3(audio)},
948-
auth_kwargs={
949-
"auth_token": cls.hidden.auth_token_comfy_org,
950-
"comfy_api_key": cls.hidden.api_key_comfy_org,
951-
},
952866
)
953-
response_api = await operation.execute()
954867
if not response_api.audio:
955868
raise ValueError("No audio file was received in response.")
956869
return IO.NodeOutput(audio_bytes_to_audio_input(base64.b64decode(response_api.audio)))

comfy_api_nodes/util/client.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ class _PollUIState:
7777

7878

7979
_RETRY_STATUS = {408, 429, 500, 502, 503, 504}
80-
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed"]
80+
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished"]
8181
FAILED_STATUSES = ["cancelled", "canceled", "fail", "failed", "error"]
8282
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted"]
8383

@@ -589,7 +589,7 @@ async def _monitor(stop_evt: asyncio.Event, start_ts: float):
589589
operation_id = _generate_operation_id(method, cfg.endpoint.path, attempt)
590590
logging.debug("[DEBUG] HTTP %s %s (attempt %d)", method, url, attempt)
591591

592-
payload_headers = {"Accept": "*/*"}
592+
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
593593
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
594594
payload_headers.update(get_auth_header(cfg.node_cls))
595595
if cfg.endpoint.headers:

0 commit comments

Comments
 (0)