diff --git a/src/iartisanxl/diffusers_patch/ip_adapter_attention_processor.py b/src/iartisanxl/diffusers_patch/ip_adapter_attention_processor.py index b6cc36b..0f48c58 100644 --- a/src/iartisanxl/diffusers_patch/ip_adapter_attention_processor.py +++ b/src/iartisanxl/diffusers_patch/ip_adapter_attention_processor.py @@ -109,7 +109,7 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module): the weight scale of image prompt. """ - def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0): + def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale=1.0, block_transformer_name=None): super().__init__() if not hasattr(F, "scaled_dot_product_attention"): @@ -120,6 +120,7 @@ def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=(4,), scale self.hidden_size = hidden_size self.cross_attention_dim = cross_attention_dim self.num_tokens = num_tokens + self.block_transformer_name = block_transformer_name if not isinstance(scale, list): scale = [scale] * len(num_tokens) @@ -226,7 +227,15 @@ def __call__( current_ip_hidden_states = current_ip_hidden_states * mask_downsample - hidden_states = hidden_states + scale * current_ip_hidden_states + scale_value = scale + + if isinstance(scale, dict): + if self.block_transformer_name in scale: + scale_value = scale[self.block_transformer_name] + else: + scale_value = 1.0 + + hidden_states = hidden_states + scale_value * current_ip_hidden_states # linear proj hidden_states = attn.to_out[0](hidden_states) diff --git a/src/iartisanxl/graph/nodes/ip_adapter_merge_node.py b/src/iartisanxl/graph/nodes/ip_adapter_merge_node.py index 0654a83..33883dd 100644 --- a/src/iartisanxl/graph/nodes/ip_adapter_merge_node.py +++ b/src/iartisanxl/graph/nodes/ip_adapter_merge_node.py @@ -7,9 +7,9 @@ class IPAdapterMergeNode(Node): OUTPUTS = ["ip_adapter"] def __call__(self) -> dict: - self.unet.set_attn_processor(AttnProcessor2_0()) - - if self.ip_adapter is not None: + if self.ip_adapter is None: + self.unet.set_attn_processor(AttnProcessor2_0()) + else: ip_adapters = self.ip_adapter if isinstance(ip_adapters, dict): @@ -17,13 +17,30 @@ def __call__(self) -> dict: weights = [] scales = [] + reload_weights = False for ip_adapter in ip_adapters: + if ip_adapter.get("reload_weights", False): + reload_weights = True + ip_adapter["reload_weights"] = False + weights.append(ip_adapter["weights"]) - scales.append(ip_adapter["scale"]) - attn_procs = self.convert_ip_adapter_attn_to_diffusers(weights) - self.unet.set_attn_processor(attn_procs) + scale = 0.0 + + if ip_adapter.get("enabled", False): + scale = ( + ip_adapter["granular_scale"] + if ip_adapter.get("granular_scale_enabled", False) + else ip_adapter.get("scale", 0.0) + ) + + scales.append(scale) + + if reload_weights: + self.unet.set_attn_processor(AttnProcessor2_0()) + attn_procs = self.convert_ip_adapter_attn_to_diffusers(weights) + self.unet.set_attn_processor(attn_procs) for attn_processor in self.unet.attn_processors.values(): if isinstance(attn_processor, IPAdapterAttnProcessor2_0): @@ -69,11 +86,15 @@ def convert_ip_adapter_attn_to_diffusers(self, state_dicts): # IP-Adapter Plus num_image_text_embeds += [state_dict["image_proj"]["latents"].shape[1]] + name_parts = name.split(".") + block_transformer_name = ".".join(name_parts[:4]) + attn_procs[name] = attn_processor_class( hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0, num_tokens=num_image_text_embeds, + block_transformer_name=block_transformer_name, ).to(dtype=self.torch_dtype, device=self.device) value_dict = {} diff --git a/src/iartisanxl/graph/nodes/ip_adapter_node.py b/src/iartisanxl/graph/nodes/ip_adapter_node.py index 457cf9b..085a39e 100644 --- a/src/iartisanxl/graph/nodes/ip_adapter_node.py +++ b/src/iartisanxl/graph/nodes/ip_adapter_node.py @@ -13,21 +13,44 @@ class IPAdapterNode(Node): OPTIONAL_INPUTS = ["mask_alpha_image"] OUTPUTS = ["ip_adapter"] - def __init__(self, type_index: int, adapter_type: str, adapter_scale: float = None, **kwargs): + def __init__( + self, + type_index: int, + adapter_type: str, + adapter_scale: float = None, + granular_scale_enabled: bool = False, + granular_scale: dict = None, + **kwargs, + ): super().__init__(**kwargs) self.type_index = type_index self.adapter_type = adapter_type self.adapter_scale = adapter_scale + self.adapter_granuler_scale = granular_scale + self.granular_scale_enabled = granular_scale_enabled self.ip_image_prompt_embeds = None + self.reload_weights = True self.clip_image_processor = CLIPImageProcessor() - def update_adapter(self, type_index: int, adapter_type: str, enabled: bool, adapter_scale: float = None): + def update_adapter( + self, + type_index: int, + adapter_type: str, + enabled: bool, + adapter_scale: float = None, + granular_scale_enabled: bool = False, + granular_scale: dict = None, + reload_weights: bool = False, + ): self.type_index = type_index self.adapter_type = adapter_type self.enabled = enabled self.adapter_scale = adapter_scale + self.granular_scale_enabled = granular_scale_enabled + self.adapter_granuler_scale = granular_scale + self.reload_weights = reload_weights self.set_updated() def to_dict(self): @@ -35,6 +58,8 @@ def to_dict(self): node_dict["type_index"] = self.type_index node_dict["adapter_type"] = self.adapter_type node_dict["adapter_scale"] = self.adapter_scale + node_dict["granular_scale_enabled"] = self.granular_scale_enabled + node_dict["adapter_granuler_scale"] = self.adapter_granuler_scale return node_dict @classmethod @@ -43,12 +68,16 @@ def from_dict(cls, node_dict, _callbacks=None): node.type_index = node_dict["type_index"] node.adapter_type = node_dict["adapter_type"] node.adapter_scale = node_dict["adapter_scale"] + node.granular_scale_enabled = node_dict["granular_scale_enabled"] + node.adapter_granuler_scale = node_dict["adapter_granuler_scale"] return node def update_inputs(self, node_dict): self.type_index = node_dict["type_index"] self.adapter_type = node_dict["adapter_type"] self.adapter_scale = node_dict["adapter_scale"] + self.granular_scale_enabled = node_dict["granular_scale_enabled"] + self.adapter_granuler_scale = node_dict["adapter_granuler_scale"] def __call__(self) -> dict: image_projection = self.convert_ip_adapter_image_proj_to_diffusers(self.ip_adapter_model["image_proj"]) @@ -65,7 +94,7 @@ def __call__(self) -> dict: image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(image, output_hidden_states) # save_embeds = torch.cat([uncond_image_prompt_embeds, image_prompt_embeds]) - # torch.save(save_embeds, "C:/Users/Ozzy/Desktop/iartisanxl_style_test.ipadpt") + # torch.save(save_embeds, "iartisanxl_style_test.ipadpt") tensor_mask = None if self.mask_alpha_image is not None: @@ -81,9 +110,13 @@ def __call__(self) -> dict: "weights": self.ip_adapter_model, "image_prompt_embeds": image_prompt_embeds, "uncond_image_prompt_embeds": uncond_image_prompt_embeds, + "enabled": self.enabled, "scale": self.adapter_scale, + "granular_scale_enabled": self.granular_scale_enabled, + "granular_scale": self.adapter_granuler_scale, "tensor_mask": tensor_mask, "image_projection": image_projection, + "reload_weights": self.reload_weights, } return self.values diff --git a/src/iartisanxl/modules/common/ip_adapter/advanced_widget.py b/src/iartisanxl/modules/common/ip_adapter/advanced_widget.py new file mode 100644 index 0000000..83467a3 --- /dev/null +++ b/src/iartisanxl/modules/common/ip_adapter/advanced_widget.py @@ -0,0 +1,114 @@ +from PyQt6.QtCore import Qt, pyqtSignal +from PyQt6.QtWidgets import QCheckBox, QFrame, QHBoxLayout, QLabel, QPushButton, QVBoxLayout, QWidget +from superqt import QLabeledDoubleSlider + +from iartisanxl.modules.common.ip_adapter.ip_adapter_data_object import IPAdapterDataObject + + +class AdvancedWidget(QWidget): + advanced_canceled = pyqtSignal() + granular_enabled = pyqtSignal(bool) + + def __init__(self, ip_adapter: IPAdapterDataObject): + super().__init__() + + self.ip_adapter = ip_adapter + self.attention_values = { + "down_1": [1.0, 1.0], + "down_2": [1.0, 1.0], + "mid": [1.0], + "up_0": [1.0, 1.0, 1.0], + "up_1": [1.0, 1.0, 1.0], + } + + self.sliders = {} + self.frames = [] + + self.init_ui() + + def init_ui(self): + main_layout = QVBoxLayout() + + granular_scales_checkbox = QCheckBox("Enable granular scales") + granular_scales_checkbox.stateChanged.connect(self.on_granular) + main_layout.addWidget(granular_scales_checkbox) + + sections_layout = QHBoxLayout() + + for section, values in self.attention_values.items(): + frame = QFrame() + frame.setDisabled(True) + frame.setObjectName("block_frame") + + blocks_layout = QVBoxLayout() + section_label = QLabel(f"{section.capitalize()} Blocks") + blocks_layout.addWidget(section_label) + + # Loop and create all the sliders for the section + for i, value in enumerate(values): + attention_layout = QHBoxLayout() + attention_label = QLabel( + f"Attention {i+1}" + ) # this is the number of the count in the total attention vars + attention_layout.addWidget(attention_label) + attention_slider = QLabeledDoubleSlider(Qt.Orientation.Horizontal) + attention_slider.setRange(0.0, 1.0) + attention_slider.setValue(value) + attention_slider.valueChanged.connect(lambda val, sec=section, idx=i: self.update_scale(val, sec, idx)) + attention_layout.addWidget(attention_slider) + blocks_layout.addLayout(attention_layout) + + self.sliders.setdefault(section, []).append(attention_slider) + + frame.setLayout(blocks_layout) + sections_layout.addWidget(frame) + self.frames.append(frame) + + main_layout.addLayout(sections_layout) + main_layout.addStretch() + + button_layout = QHBoxLayout() + save_button = QPushButton("Set scales") + save_button.clicked.connect(self.on_save) + button_layout.addWidget(save_button) + cancel_button = QPushButton("Cancel") + cancel_button.clicked.connect(self.on_cancel) + button_layout.addWidget(cancel_button) + + main_layout.addLayout(button_layout) + + self.setLayout(main_layout) + + def update_scale(self, value, section, index): + self.attention_values[section][index] = value + + def on_cancel(self): + self.advanced_canceled.emit() + + def on_save(self): + scales = {} + for section, values in self.attention_values.items(): + if section.startswith("down_"): + block = "down_blocks" + block_num = section.split("_")[1] + elif section == "mid": + block = "mid_block" + block_num = "" + elif section.startswith("up_"): + block = "up_blocks" + block_num = section.split("_")[1] + + for i, value in enumerate(values): + key = f"{block}.{block_num}.attentions.{i}" + scales[key] = value + + self.ip_adapter.granular_scale = scales + self.advanced_canceled.emit() + + def on_granular(self, state): + is_enabled = state != Qt.CheckState.Unchecked.value + + for frame in self.frames: + frame.setEnabled(is_enabled) + + self.granular_enabled.emit(is_enabled) diff --git a/src/iartisanxl/modules/common/ip_adapter/ip_adapter_data_object.py b/src/iartisanxl/modules/common/ip_adapter/ip_adapter_data_object.py index f184de0..2c18c59 100644 --- a/src/iartisanxl/modules/common/ip_adapter/ip_adapter_data_object.py +++ b/src/iartisanxl/modules/common/ip_adapter/ip_adapter_data_object.py @@ -13,6 +13,8 @@ class IPAdapterDataObject: adapter_type: str = attr.ib(default=None) type_index: int = attr.ib(default=0) ip_adapter_scale: float = attr.ib(default=1.0) + granular_scale_enabled: bool = attr.ib(default=False) + granular_scale: dict = attr.ib(default=None) enabled: bool = attr.ib(default=True) node_id: int = attr.ib(default=None) adapter_id: int = attr.ib(default=None) diff --git a/src/iartisanxl/modules/common/ip_adapter/ip_adapter_dialog.py b/src/iartisanxl/modules/common/ip_adapter/ip_adapter_dialog.py index e616016..9ffb721 100644 --- a/src/iartisanxl/modules/common/ip_adapter/ip_adapter_dialog.py +++ b/src/iartisanxl/modules/common/ip_adapter/ip_adapter_dialog.py @@ -1,6 +1,6 @@ from PyQt6.QtCore import QEvent, QSettings, Qt from PyQt6.QtGui import QColor, QCursor, QGuiApplication, QPixmap -from PyQt6.QtWidgets import QApplication, QComboBox, QHBoxLayout, QLabel, QSlider +from PyQt6.QtWidgets import QApplication, QComboBox, QFrame, QHBoxLayout, QLabel, QPushButton, QSlider from superqt import QDoubleSlider from iartisanxl.app.event_bus import EventBus @@ -8,6 +8,7 @@ from iartisanxl.buttons.color_button import ColorButton from iartisanxl.buttons.eyedropper_button import EyeDropperButton from iartisanxl.modules.common.dialogs.base_dialog import BaseDialog +from iartisanxl.modules.common.ip_adapter.advanced_widget import AdvancedWidget from iartisanxl.modules.common.ip_adapter.image_section_widget import ImageSectionWidget from iartisanxl.modules.common.ip_adapter.ip_adapter_data_object import IPAdapterDataObject from iartisanxl.modules.common.ip_adapter.mask_section_widget import MaskSectionWidget @@ -53,15 +54,25 @@ def init_ui(self): self.type_combo.addItem("IP Adapter Composition", "ip_plus_composition_sdxl") top_layout.addWidget(self.type_combo) + self.scale_frame = QFrame() + self.scale_frame.setObjectName("main_scale_frame") + scale_layout = QHBoxLayout() adapter_scale_label = QLabel("Adapter scale:") - top_layout.addWidget(adapter_scale_label) + scale_layout.addWidget(adapter_scale_label) self.adapter_scale_slider = QDoubleSlider(Qt.Orientation.Horizontal) self.adapter_scale_slider.setRange(0.0, 1.0) self.adapter_scale_slider.setValue(self.adapter_scale) self.adapter_scale_slider.valueChanged.connect(self.on_adapter_scale_changed) - top_layout.addWidget(self.adapter_scale_slider) + scale_layout.addWidget(self.adapter_scale_slider) self.adapter_scale_value_label = QLabel(f"{self.adapter_scale}") - top_layout.addWidget(self.adapter_scale_value_label) + scale_layout.addWidget(self.adapter_scale_value_label) + self.scale_frame.setLayout(scale_layout) + top_layout.addWidget(self.scale_frame) + + advanced_button = QPushButton("Advanced") + advanced_button.setFixedWidth(80) + advanced_button.clicked.connect(self.on_advanced_clicked) + top_layout.addWidget(advanced_button) self.main_layout.addLayout(top_layout) @@ -112,6 +123,12 @@ def init_ui(self): self.mask_section_widget.setVisible(False) self.main_layout.addWidget(self.mask_section_widget) + self.advanced_widget = AdvancedWidget(self.adapter) + self.advanced_widget.setVisible(False) + self.advanced_widget.advanced_canceled.connect(self.on_cancel_advanced) + self.advanced_widget.granular_enabled.connect(self.on_granular) + self.main_layout.addWidget(self.advanced_widget) + self.main_layout.setStretch(0, 0) self.main_layout.setStretch(1, 0) self.main_layout.setStretch(2, 1) @@ -184,13 +201,14 @@ def connect_mask_editor(self): def on_add_mask_clicked(self): self.connect_mask_editor() self.image_section_widget.hide() + self.advanced_widget.hide() self.mask_section_widget.show() def on_mask_saved(self, thumb_pixmap: QPixmap): self.connect_image_editor() self.image_section_widget.ip_mask_item.set_pixmap(thumb_pixmap) - self.image_section_widget.show() self.mask_section_widget.hide() + self.image_section_widget.show() self.image_section_widget.add_mask_button.setText("Edit mask") def on_cancel_mask(self): @@ -219,3 +237,17 @@ def eventFilter(self, obj, event): self.color_button.set_color(rgb_color) return True return super().eventFilter(obj, event) + + def on_advanced_clicked(self): + self.mask_section_widget.hide() + self.image_section_widget.hide() + self.advanced_widget.show() + + def on_cancel_advanced(self): + self.mask_section_widget.hide() + self.advanced_widget.hide() + self.image_section_widget.show() + + def on_granular(self, state): + self.scale_frame.setDisabled(state) + self.adapter.granular_scale_enabled = state diff --git a/src/iartisanxl/theme/stylesheet.qss b/src/iartisanxl/theme/stylesheet.qss index 79e02ed..03ca1a0 100644 --- a/src/iartisanxl/theme/stylesheet.qss +++ b/src/iartisanxl/theme/stylesheet.qss @@ -515,4 +515,16 @@ MaskWidget ImageEditor { IpMaskItem QLabel#mask_label { border: 1px solid #848485; +} + +IPAdapterDialog AdvancedWidget QCheckBox { + background: #181b28; +} + +IPAdapterDialog QFrame:disabled { + color: #585858; +} + +IPAdapterDialog AdvancedWidget QFrame#block_frame { + border: 1px solid #848485; } \ No newline at end of file diff --git a/src/iartisanxl/threads/node_graph_thread.py b/src/iartisanxl/threads/node_graph_thread.py index dbcecbe..c65af4d 100644 --- a/src/iartisanxl/threads/node_graph_thread.py +++ b/src/iartisanxl/threads/node_graph_thread.py @@ -349,7 +349,11 @@ def run(self): # noqa: C901 if len(added_ip_adapters) > 0: for ip_adapter in added_ip_adapters: ip_adapter_node = IPAdapterNode( - ip_adapter.type_index, ip_adapter.adapter_type, ip_adapter.ip_adapter_scale + ip_adapter.type_index, + ip_adapter.adapter_type, + ip_adapter.ip_adapter_scale, + ip_adapter.granular_scale_enabled, + ip_adapter.granular_scale, ) self.node_graph.add_node(ip_adapter_node) ip_adapter.node_id = ip_adapter_node.id @@ -388,8 +392,10 @@ def run(self): # noqa: C901 if len(modified_ip_adapters) > 0: for ip_adapter in modified_ip_adapters: ip_adapter_node = self.node_graph.get_node(ip_adapter.node_id) + type_changed = False if ip_adapter.type_index != ip_adapter_node.type_index: + type_changed = True # disconnect old model ip_adapter_model_node = self.get_ip_adapter_model(ip_adapter_node.adapter_type) ip_adapter_node.disconnect("ip_adapter_model", ip_adapter_model_node, "ip_adapter_model") @@ -400,7 +406,13 @@ def run(self): # noqa: C901 # update rest of params ip_adapter_node.update_adapter( - ip_adapter.type_index, ip_adapter.adapter_type, ip_adapter.enabled, ip_adapter.ip_adapter_scale + ip_adapter.type_index, + ip_adapter.adapter_type, + ip_adapter.enabled, + ip_adapter.ip_adapter_scale, + ip_adapter.granular_scale_enabled, + ip_adapter.granular_scale, + type_changed, ) if ip_adapter.mask_image is not None: