Add compel node and conditioning field type #3265
Merged
Commits (15):
d99a08a  Add compel node and conditioning field type (StAlKeR7779)
8f460b9  Make latent generation nodes use conditions instead of prompt (StAlKeR7779)
8cb2fa8  Restore log_tokenization check (StAlKeR7779)
37916a2  Use textual inversion manager from pipeline, remove extra conditionin… (StAlKeR7779)
89f1909  Update default graph (StAlKeR7779)
d753cff  Undo debug message (StAlKeR7779)
0b0068a  Merge branch 'main' into feat/compel_node (StAlKeR7779)
56d3cbe  Merge branch 'main' into feat/compel_node (StAlKeR7779)
7d221e2  Combine conditioning to one field(better fits for multiple type condi… (StAlKeR7779)
1e6adf0  Fix default graph and test (StAlKeR7779)
81ec476  Revert seed field addition (StAlKeR7779)
85c3382  Merge branch 'main' into feat/compel_node (blessedcoolant)
5012f61  Separate conditionings back to positive and negative (StAlKeR7779)
58d7833  Review changes (StAlKeR7779)
a80fe05  Rename compel node (StAlKeR7779)
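For context, the new node wraps the compel prompt-weighting library. Outside of InvokeAI, the same parse-prompt-into-conditioning step looks roughly like the sketch below. This is not code from the PR: the model ID and prompts are illustrative, and the node added here routes the result through InvokeAI's services instead of calling a pipeline directly.

```python
# Minimal sketch (not part of this PR): prompt -> conditioning tensor with compel.
# Assumes a diffusers Stable Diffusion pipeline; the model ID and prompts are illustrative.
from diffusers import StableDiffusionPipeline
from compel import Compel

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
compel = Compel(tokenizer=pipe.tokenizer, text_encoder=pipe.text_encoder)

# compel parses weighting syntax, e.g. "++" upweights a term
conditioning = compel.build_conditioning_tensor("a forest++ at dawn, oil painting")
negative = compel.build_conditioning_tensor("blurry, low quality")

# diffusers pipelines accept precomputed embeddings via prompt_embeds
image = pipe(prompt_embeds=conditioning, negative_prompt_embeds=negative).images[0]
```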
New file (+245 lines):

```python
from typing import Literal, Optional, Union
from pydantic import BaseModel, Field

from invokeai.app.invocations.util.choose_model import choose_model
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig

from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion import InvokeAIDiffuserComponent
from ...backend.stable_diffusion.textual_inversion_manager import TextualInversionManager

from compel import Compel
from compel.prompt_parser import (
    Blend,
    CrossAttentionControlSubstitute,
    FlattenedPrompt,
    Fragment,
)

from invokeai.backend.globals import Globals


class ConditioningField(BaseModel):
    conditioning_name: Optional[str] = Field(default=None, description="The name of conditioning data")

    class Config:
        schema_extra = {"required": ["conditioning_name"]}


class CompelOutput(BaseInvocationOutput):
    """Compel parser output"""

    # fmt: off
    type: Literal["compel_output"] = "compel_output"

    conditioning: ConditioningField = Field(default=None, description="Conditioning")
    # fmt: on


class CompelInvocation(BaseInvocation):
    """Parse prompt using compel package to conditioning."""

    type: Literal["compel"] = "compel"

    prompt: str = Field(default="", description="Prompt")
    model: str = Field(default="", description="Model to use")

    # Schema customisation
    class Config(InvocationConfig):
        schema_extra = {
            "ui": {
                "title": "Prompt (Compel)",
                "tags": ["prompt", "compel"],
                "type_hints": {
                    "model": "model"
                }
            },
        }

    def invoke(self, context: InvocationContext) -> CompelOutput:

        # TODO: load without model
        model = choose_model(context.services.model_manager, self.model)
        pipeline = model["model"]
        tokenizer = pipeline.tokenizer
        text_encoder = pipeline.text_encoder

        # TODO: global? input?
        # use_full_precision = precision == "float32" or precision == "autocast"
        # use_full_precision = False

        # TODO: redo TI when separate model loading is implemented
        # textual_inversion_manager = TextualInversionManager(
        #     tokenizer=tokenizer,
        #     text_encoder=text_encoder,
        #     full_precision=use_full_precision,
        # )

        def load_huggingface_concepts(concepts: list[str]):
            pipeline.textual_inversion_manager.load_huggingface_concepts(concepts)

        # apply the concepts library to the prompt
        prompt_str = pipeline.textual_inversion_manager.hf_concepts_library.replace_concepts_with_triggers(
            self.prompt,
            lambda concepts: load_huggingface_concepts(concepts),
            pipeline.textual_inversion_manager.get_all_trigger_strings(),
        )

        # lazy-load any deferred textual inversions.
        # this might take a couple of seconds the first time a textual inversion is used.
        pipeline.textual_inversion_manager.create_deferred_token_ids_for_any_trigger_terms(
            prompt_str
        )

        compel = Compel(
            tokenizer=tokenizer,
            text_encoder=text_encoder,
            textual_inversion_manager=pipeline.textual_inversion_manager,
            dtype_for_device_getter=torch_dtype,
            truncate_long_prompts=True,  # TODO:
        )

        # TODO: support legacy blend?

        prompt: Union[FlattenedPrompt, Blend] = Compel.parse_prompt_string(prompt_str)

        if getattr(Globals, "log_tokenization", False):
            log_tokenization_for_prompt_object(prompt, tokenizer)

        c, options = compel.build_conditioning_tensor_for_prompt_object(prompt)

        # TODO: long prompt support
        # if not self.truncate_long_prompts:
        #     [c, uc] = compel.pad_conditioning_tensors_to_same_length([c, uc])

        ec = InvokeAIDiffuserComponent.ExtraConditioningInfo(
            tokens_count_including_eos_bos=get_max_token_count(tokenizer, prompt),
            cross_attention_control_args=options.get("cross_attention_control", None),
        )

        conditioning_name = f"{context.graph_execution_state_id}_{self.id}_conditioning"

        # TODO: hacky but works ;D maybe rename latents somehow?
        context.services.latents.set(conditioning_name, (c, ec))

        return CompelOutput(
            conditioning=ConditioningField(
                conditioning_name=conditioning_name,
            ),
        )


def get_max_token_count(
    tokenizer, prompt: Union[FlattenedPrompt, Blend], truncate_if_too_long=False
) -> int:
    if type(prompt) is Blend:
        blend: Blend = prompt
        return max(
            [
                get_max_token_count(tokenizer, c, truncate_if_too_long)
                for c in blend.prompts
            ]
        )
    else:
        return len(
            get_tokens_for_prompt_object(tokenizer, prompt, truncate_if_too_long)
        )


def get_tokens_for_prompt_object(
    tokenizer, parsed_prompt: FlattenedPrompt, truncate_if_too_long=True
) -> list[str]:
    if type(parsed_prompt) is Blend:
        raise ValueError(
            "Blend is not supported here - you need to get tokens for each of its .children"
        )

    text_fragments = [
        x.text
        if type(x) is Fragment
        else (
            " ".join([f.text for f in x.original])
            if type(x) is CrossAttentionControlSubstitute
            else str(x)
        )
        for x in parsed_prompt.children
    ]
    text = " ".join(text_fragments)
    tokens = tokenizer.tokenize(text)
    if truncate_if_too_long:
        max_tokens_length = tokenizer.model_max_length - 2  # typically 75
        tokens = tokens[0:max_tokens_length]
    return tokens


def log_tokenization_for_prompt_object(
    p: Union[Blend, FlattenedPrompt], tokenizer, display_label_prefix=None
):
    display_label_prefix = display_label_prefix or ""
    if type(p) is Blend:
        blend: Blend = p
        for i, c in enumerate(blend.prompts):
            log_tokenization_for_prompt_object(
                c,
                tokenizer,
                display_label_prefix=f"{display_label_prefix}(blend part {i + 1}, weight={blend.weights[i]})",
            )
    elif type(p) is FlattenedPrompt:
        flattened_prompt: FlattenedPrompt = p
        if flattened_prompt.wants_cross_attention_control:
            original_fragments = []
            edited_fragments = []
            for f in flattened_prompt.children:
                if type(f) is CrossAttentionControlSubstitute:
                    original_fragments += f.original
                    edited_fragments += f.edited
                else:
                    original_fragments.append(f)
                    edited_fragments.append(f)

            original_text = " ".join([x.text for x in original_fragments])
            log_tokenization_for_text(
                original_text,
                tokenizer,
                display_label=f"{display_label_prefix}(.swap originals)",
            )
            edited_text = " ".join([x.text for x in edited_fragments])
            log_tokenization_for_text(
                edited_text,
                tokenizer,
                display_label=f"{display_label_prefix}(.swap replacements)",
            )
        else:
            text = " ".join([x.text for x in flattened_prompt.children])
            log_tokenization_for_text(
                text, tokenizer, display_label=display_label_prefix
            )


def log_tokenization_for_text(text, tokenizer, display_label=None, truncate_if_too_long=False):
    """Shows how the prompt is tokenized.
    Usually tokens have '</w>' to indicate end-of-word,
    but for readability it has been replaced with ' '.
    """
    tokens = tokenizer.tokenize(text)
    tokenized = ""
    discarded = ""
    usedTokens = 0
    totalTokens = len(tokens)

    for i in range(0, totalTokens):
        token = tokens[i].replace("</w>", " ")
        # alternate color
        s = (usedTokens % 6) + 1
        if truncate_if_too_long and i >= tokenizer.model_max_length:
            discarded = discarded + f"\x1b[0;3{s};40m{token}"
        else:
            tokenized = tokenized + f"\x1b[0;3{s};40m{token}"
            usedTokens += 1

    if usedTokens > 0:
        print(f'\n>> [TOKENLOG] Tokens {display_label or ""} ({usedTokens}):')
        print(f"{tokenized}\x1b[0m")

    if discarded != "":
        print(f"\n>> [TOKENLOG] Tokens Discarded ({totalTokens - usedTokens}):")
        print(f"{discarded}\x1b[0m")
```
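Per the commit history, the latent-generation nodes now take conditioning fields instead of raw prompt strings, and the default graph was updated to match. The wiring looks roughly like the following sketch; the node types and the positive/negative field names here are hypothetical, since the updated graph and latents nodes are not shown in the loaded diff.

```python
# Rough illustration only; node types and field names are hypothetical, since the
# updated default graph and latents nodes are not part of the loaded diff.
graph = {
    "nodes": {
        "prompt_pos": {"id": "prompt_pos", "type": "compel", "prompt": "a forest at dawn"},
        "prompt_neg": {"id": "prompt_neg", "type": "compel", "prompt": "blurry, low quality"},
        "latents": {"id": "latents", "type": "t2l"},
    },
    "edges": [
        # the compel node's "conditioning" output feeds the latents node's
        # positive/negative conditioning inputs
        {"source": {"node_id": "prompt_pos", "field": "conditioning"},
         "destination": {"node_id": "latents", "field": "positive_conditioning"}},
        {"source": {"node_id": "prompt_neg", "field": "conditioning"},
         "destination": {"node_id": "latents", "field": "negative_conditioning"}},
    ],
}
```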