Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 22 additions & 3 deletions daras_ai_v2/enum_selector_widget.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ def enum_multiselect(
label: str = "",
checkboxes=True,
allow_none=True,
tooltip: dict[E, str] | None = None,
tooltip_placement: str = "top",
):
try:
deprecated = enum_cls._deprecated()
Expand All @@ -27,8 +29,14 @@ def enum_multiselect(
if e in deprecated and e.name not in value:
continue
enums.append(e)
enum_names = [e.name for e in enums]
enum_labels = {e.name: _default_format_func(e) for e in enums}

enum_names = []
enum_labels = {}
enum_lookup = {}
for e in enums:
enum_names.append(e.name)
enum_labels[e.name] = _default_format_func(e)
enum_lookup[e.name] = e

if checkboxes:
if label:
Expand All @@ -41,7 +49,18 @@ def render(name):
if inner_key not in gui.session_state:
gui.session_state[inner_key] = name in selected

gui.checkbox(enum_labels.get(name), key=inner_key)
enum_obj = enum_lookup[name]
tooltip_text = tooltip.get(enum_obj) if tooltip else None

if tooltip_text:
gui.checkbox(
enum_labels.get(name),
key=inner_key,
help=tooltip_text,
tooltip_placement=tooltip_placement,
)
else:
gui.checkbox(enum_labels.get(name), key=inner_key)

if gui.session_state.get(inner_key):
ret_val.append(name)
Expand Down
7 changes: 7 additions & 0 deletions daras_ai_v2/stable_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,13 @@ def _deprecated(cls):
}


model_pricing_tooltips = {
Text2ImgModels.dall_e_3: "15 Cr",
Text2ImgModels.gpt_image_1: "3, 10 or 40 Cr",
Text2ImgModels.nano_banana: "8 Cr",
}


class Img2ImgModels(Enum):
flux_pro_kontext = "FLUX.1 Pro Kontext (fal.ai)"

Expand Down
9 changes: 5 additions & 4 deletions recipes/CompareText2Img.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,8 @@ def related_workflows(self) -> list:
]

def render_form_v2(self):
from daras_ai_v2.stable_diffusion import model_pricing_tooltips

gui.text_area(
"""
#### 👩‍💻 Prompt
Expand All @@ -118,10 +120,6 @@ def render_form_v2(self):
gui.caption(
"""
Each selected model costs 2 credits ($.02) / image except where noted.

Dalle-3: 15 Cr
Nano banana: 8 Cr
GPT-image: 3, 10 or 40 Cr
"""
)

Expand All @@ -131,9 +129,12 @@ def render_form_v2(self):
[Check out our prompt guide](https://docs.google.com/presentation/d/1RaoMP0l7FnBZovDAR42zVmrUND9W5DW6eWet-pi6kiE/edit#slide=id.g210b1678eba_0_26).
"""
)

selected_models = enum_multiselect(
Text2ImgModels,
key="selected_models",
tooltip=model_pricing_tooltips,
tooltip_placement="right",
)
if selected_models and set(selected_models) <= {Text2ImgModels.flux_1_dev.name}:
loras_input()
Expand Down
Loading