Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
94 commits
Select commit Hold shift + click to select a range
bf7d24f
Core implementation of ControlNet and MultiControlNet.
GreggHelt2 Apr 29, 2023
a83d412
Added support for ControlNet and MultiControlNet to legacy non-nodal …
GreggHelt2 Apr 29, 2023
ddc9695
Added example of using ControlNet with legacy Txt2Img generator
GreggHelt2 Apr 29, 2023
f2f0f0f
Resolving rebase conflict
GreggHelt2 Apr 30, 2023
29ae468
Added first controlnet preprocessor node for canny edge detection.
GreggHelt2 Apr 30, 2023
d7b1d9c
Initial port of controlnet node support from generator-based TextToIm…
GreggHelt2 Apr 30, 2023
638d408
Switching to ControlField for output from controlnet nodes.
GreggHelt2 May 4, 2023
f09c82f
Resolving conflicts in rebase to origin/main
GreggHelt2 May 4, 2023
1f929f2
Refactored ControlNet nodes so they subclass from PreprocessedControl…
GreggHelt2 May 4, 2023
7bfaa09
changes to base class for controlnet nodes
GreggHelt2 May 5, 2023
a2c4d68
Added HED, LineArt, and OpenPose ControlNet nodes
GreggHelt2 May 5, 2023
b9f41cb
Added an additional "raw_processed_image" output port to controlnets,…
GreggHelt2 May 5, 2023
1efd8d9
Added more preprocessor nodes for:
GreggHelt2 May 5, 2023
955bd93
Prep for splitting pre-processor and controlnet nodes
GreggHelt2 May 6, 2023
a540a75
Refactored controlnet nodes: split out controlnet stuff into separate…
GreggHelt2 May 6, 2023
9f4a0df
Added resizing of controlnet image based on noise latent. Fixes a ten…
GreggHelt2 May 6, 2023
309efbb
More rebase repair.
GreggHelt2 May 7, 2023
b374d12
Added support for using multiple control nets. Unfortunately this bre…
GreggHelt2 May 9, 2023
a4fde5f
Fixed use of ControlNet control_weight parameter
GreggHelt2 May 9, 2023
0e85659
Fixed lint-ish formatting error
GreggHelt2 May 9, 2023
7c7d925
Core implementation of ControlNet and MultiControlNet.
GreggHelt2 Apr 29, 2023
ff3189a
Added first controlnet preprocessor node for canny edge detection.
GreggHelt2 Apr 30, 2023
ed183ce
Initial port of controlnet node support from generator-based TextToIm…
GreggHelt2 Apr 30, 2023
8daffee
Switching to ControlField for output from controlnet nodes.
GreggHelt2 May 4, 2023
babcb4c
Refactored controlnet node to output ControlField that bundles contro…
GreggHelt2 May 4, 2023
4963411
changes to base class for controlnet nodes
GreggHelt2 May 5, 2023
5972ee9
Added more preprocessor nodes for:
GreggHelt2 May 5, 2023
7d81c0a
Prep for splitting pre-processor and controlnet nodes
GreggHelt2 May 6, 2023
6500281
Refactored controlnet nodes: split out controlnet stuff into separate…
GreggHelt2 May 6, 2023
548657c
Added resizing of controlnet image based on noise latent. Fixes a ten…
GreggHelt2 May 6, 2023
af5d02b
Cleaning up TextToLatent arg testing
GreggHelt2 May 11, 2023
61b8965
Cleaning up mistakes after rebase.
GreggHelt2 May 11, 2023
c04e4c5
Removed last bits of dtype and and device hardwiring from controlnet …
GreggHelt2 May 11, 2023
13fa873
Refactored ControNet support to consolidate multiple parameters into …
GreggHelt2 May 12, 2023
b5c040c
Added support for specifying which step iteration to start using
GreggHelt2 May 12, 2023
6613181
Cleaning up prior to submitting ControlNet PR. Mostly turning off dia…
GreggHelt2 May 12, 2023
ed7e387
Added dependency on controlnet-aux v0.0.3
GreggHelt2 May 13, 2023
2f895c1
Commented out ZoeDetector. Will re-instate once there's a controlnet-…
GreggHelt2 May 13, 2023
410f872
Switched CotrolNet node modelname input from free text to default lis…
GreggHelt2 May 14, 2023
8655fc2
Fix to work with current stable release of controlnet_aux (v0.0.3). T…
GreggHelt2 May 17, 2023
640a321
Refactored most of controlnet code into its own method to declutter T…
GreggHelt2 May 18, 2023
8e818d7
Cleaning up after ControlNet refactor in TextToLatentsInvocation
GreggHelt2 May 18, 2023
80901dd
Extended node-based ControlNet support to LatentsToLatentsInvocation.
GreggHelt2 May 18, 2023
f7cc58b
chore(ui): regen api client
GreggHelt2 May 22, 2023
89b4ff2
fix(ui): add value to conditioning field
psychedelicious May 19, 2023
de83e19
fix(ui): add control field type
psychedelicious May 19, 2023
98abc48
fix(ui): fix node ui type hints
psychedelicious May 19, 2023
dbefbbe
fix(nodes): controlnet input accepts list or single controlnet
psychedelicious May 19, 2023
93035e3
Moved to controlnet_aux v0.0.4, reinstated Zoe controlnet preprocesso…
GreggHelt2 May 23, 2023
dc7cac1
Core implementation of ControlNet and MultiControlNet.
GreggHelt2 May 26, 2023
6815f88
Added first controlnet preprocessor node for canny edge detection.
GreggHelt2 Apr 30, 2023
36e5aac
Switching to ControlField for output from controlnet nodes.
GreggHelt2 May 4, 2023
e978347
Resolving conflicts in rebase to origin/main
GreggHelt2 May 4, 2023
69a235a
Refactored ControlNet nodes so they subclass from PreprocessedControl…
GreggHelt2 May 4, 2023
d9e921c
changes to base class for controlnet nodes
GreggHelt2 May 5, 2023
51e1e90
Added HED, LineArt, and OpenPose ControlNet nodes
GreggHelt2 May 5, 2023
952df4b
Added more preprocessor nodes for:
GreggHelt2 May 5, 2023
a7421aa
Prep for splitting pre-processor and controlnet nodes
GreggHelt2 May 6, 2023
a6ec5da
Refactored controlnet nodes: split out controlnet stuff into separate…
GreggHelt2 May 6, 2023
c2c4882
Added resizing of controlnet image based on noise latent. Fixes a ten…
GreggHelt2 May 6, 2023
e3b1dce
Added support for using multiple control nets. Unfortunately this bre…
GreggHelt2 May 9, 2023
f5eb6a0
Fixed use of ControlNet control_weight parameter
GreggHelt2 May 9, 2023
04a1617
Core implementation of ControlNet and MultiControlNet.
GreggHelt2 Apr 29, 2023
3ddaab5
Added first controlnet preprocessor node for canny edge detection.
GreggHelt2 Apr 30, 2023
5a5d1a5
Initial port of controlnet node support from generator-based TextToIm…
GreggHelt2 Apr 30, 2023
52ce93a
Switching to ControlField for output from controlnet nodes.
GreggHelt2 May 4, 2023
70dd3bb
Refactored controlnet node to output ControlField that bundles contro…
GreggHelt2 May 4, 2023
862e6f9
changes to base class for controlnet nodes
GreggHelt2 May 5, 2023
60b228b
Added more preprocessor nodes for:
GreggHelt2 May 5, 2023
4dcebdc
Prep for splitting pre-processor and controlnet nodes
GreggHelt2 May 6, 2023
3bedc04
Refactored controlnet nodes: split out controlnet stuff into separate…
GreggHelt2 May 6, 2023
1739740
Added resizing of controlnet image based on noise latent. Fixes a ten…
GreggHelt2 May 6, 2023
abcee6c
Cleaning up TextToLatent arg testing
GreggHelt2 May 11, 2023
bc37242
Cleaning up mistakes after rebase.
GreggHelt2 May 11, 2023
5a4714a
Removed last bits of dtype and and device hardwiring from controlnet …
GreggHelt2 May 11, 2023
b9a085a
Refactored ControNet support to consolidate multiple parameters into …
GreggHelt2 May 12, 2023
92a18d7
Added support for specifying which step iteration to start using
GreggHelt2 May 12, 2023
f3c5d11
Cleaning up prior to submitting ControlNet PR. Mostly turning off dia…
GreggHelt2 May 12, 2023
f806374
Commented out ZoeDetector. Will re-instate once there's a controlnet-…
GreggHelt2 May 13, 2023
71fbefe
Switched CotrolNet node modelname input from free text to default lis…
GreggHelt2 May 14, 2023
1c5afe0
Fix to work with current stable release of controlnet_aux (v0.0.3). T…
GreggHelt2 May 17, 2023
1976056
Refactored most of controlnet code into its own method to declutter T…
GreggHelt2 May 18, 2023
4713ce4
Cleaning up after ControlNet refactor in TextToLatentsInvocation
GreggHelt2 May 18, 2023
c8883f5
Extended node-based ControlNet support to LatentsToLatentsInvocation.
GreggHelt2 May 18, 2023
291e542
chore(ui): regen api client
psychedelicious May 19, 2023
73cae3a
fix(ui): fix node ui type hints
psychedelicious May 19, 2023
f146cc8
fix(nodes): controlnet input accepts list or single controlnet
psychedelicious May 19, 2023
450b7e9
Added Mediapipe image processor for use as ControlNet preprocessor.
GreggHelt2 May 23, 2023
00c8a52
Fixed bug where MediapipFaceProcessorInvocation was ignoring max_face…
GreggHelt2 May 24, 2023
095f66c
Added nodes for float params: ParamFloatInvocation and FloatCollectio…
GreggHelt2 May 26, 2023
7cb5079
Added mediapipe install requirement. Should be able to remove once co…
GreggHelt2 May 26, 2023
e402979
Added float to FIELD_TYPE_MAP ins constants.ts
GreggHelt2 May 26, 2023
fd5b73e
Progress toward improvement in fieldTemplateBuilder.ts getFieldType()
GreggHelt2 May 26, 2023
a4b0140
Fixed controlnet preprocessors and controlnet handling in TextToLaten…
GreggHelt2 May 26, 2023
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions invokeai/app/invocations/baseinvocation.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class UIConfig(TypedDict, total=False):
"image",
"latents",
"model",
"control",
],
]
tags: List[str]
Expand Down
8 changes: 8 additions & 0 deletions invokeai/app/invocations/collections.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,14 @@ class IntCollectionOutput(BaseInvocationOutput):
# Outputs
collection: list[int] = Field(default=[], description="The int collection")

class FloatCollectionOutput(BaseInvocationOutput):
"""A collection of floats"""

type: Literal["float_collection"] = "float_collection"

# Outputs
collection: list[float] = Field(default=[], description="The float collection")


class RangeInvocation(BaseInvocation):
"""Creates a range of numbers from start to stop with step"""
Expand Down
428 changes: 428 additions & 0 deletions invokeai/app/invocations/controlnet_image_processors.py

Large diffs are not rendered by default.

27 changes: 25 additions & 2 deletions invokeai/app/invocations/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,9 @@
from typing import Literal, Optional, Union, get_args

import numpy as np
from diffusers import ControlNetModel
from torch import Tensor
import torch

from pydantic import BaseModel, Field

Expand Down Expand Up @@ -58,6 +60,9 @@ class TextToImageInvocation(BaseInvocation, SDImageInvocation):
cfg_scale: float = Field(default=7.5, ge=1, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
progress_images: bool = Field(default=False, description="Whether or not to produce progress images during generation", )
control_model: Optional[str] = Field(default=None, description="The control model to use")
control_image: Optional[ImageField] = Field(default=None, description="The processed control image")
# fmt: on

# TODO: pass this an emitter method or something? or a session for dispatching?
Expand All @@ -78,17 +83,35 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
# Handle invalid model parameter
model = choose_model(context.services.model_manager, self.model)

# loading controlnet image (currently requires pre-processed image)
control_image = (
None if self.control_image is None
else context.services.images.get(
self.control_image.image_type, self.control_image.image_name
)
)
# loading controlnet model
if (self.control_model is None or self.control_model==''):
control_model = None
else:
# FIXME: change this to dropdown menu?
# FIXME: generalize so don't have to hardcode torch_dtype and device
control_model = ControlNetModel.from_pretrained(self.control_model,
torch_dtype=torch.float16).to("cuda")

# Get the source node id (we are invoking the prepared node)
graph_execution_state = context.services.graph_execution_manager.get(
context.graph_execution_state_id
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]

outputs = Txt2Img(model).generate(
txt2img = Txt2Img(model, control_model=control_model)
outputs = txt2img.generate(
prompt=self.prompt,
step_callback=partial(self.dispatch_progress, context, source_node_id),
control_image=control_image,
**self.dict(
exclude={"prompt"}
exclude={"prompt", "control_image" }
), # Shorthand for passing all of the parameters above manually
)
# Outputs is an infinite iterator that will return a new InvokeAIGeneratorOutput object
Expand Down
125 changes: 116 additions & 9 deletions invokeai/app/invocations/latent.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
# Copyright (c) 2023 Kyle Schouviller (https://github.com/kyle0654)

import random
from typing import Literal, Optional, Union
import einops
from typing import Literal, Optional, Union, List

from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_controlnet import MultiControlNetModel

from pydantic import BaseModel, Field, validator
import torch

Expand All @@ -11,14 +14,18 @@
from invokeai.app.util.misc import SEED_MAX, get_random_seed

from invokeai.app.util.step_callback import stable_diffusion_step_callback
from .controlnet_image_processors import ControlField

from ...backend.model_management.model_manager import ModelManager
from ...backend.util.devices import choose_torch_device, torch_dtype
from ...backend.stable_diffusion.diffusion.shared_invokeai_diffusion import PostprocessingSettings
from ...backend.image_util.seamless import configure_model_padding
from ...backend.prompting.conditioning import get_uc_and_c_and_ec

from ...backend.stable_diffusion.diffusers_pipeline import ConditioningData, StableDiffusionGeneratorPipeline, image_resized_to_grid_as_tensor
from ...backend.stable_diffusion.schedulers import SCHEDULER_MAP
from ...backend.stable_diffusion.diffusers_pipeline import ControlNetData

from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext, InvocationConfig
import numpy as np
from ..services.image_file_storage import ImageType
Expand All @@ -28,7 +35,7 @@
from ...backend.stable_diffusion import PipelineIntermediateState
from diffusers.schedulers import SchedulerMixin as Scheduler
import diffusers
from diffusers import DiffusionPipeline
from diffusers import DiffusionPipeline, ControlNetModel


class LatentsField(BaseModel):
Expand Down Expand Up @@ -84,13 +91,13 @@ def build_noise_output(latents_name: str, latents: torch.Tensor):

def get_scheduler(scheduler_name:str, model: StableDiffusionGeneratorPipeline)->Scheduler:
scheduler_class, scheduler_extra_config = SCHEDULER_MAP.get(scheduler_name, SCHEDULER_MAP['ddim'])

scheduler_config = model.scheduler.config
if "_backup" in scheduler_config:
scheduler_config = scheduler_config["_backup"]
scheduler_config = {**scheduler_config, **scheduler_extra_config, "_backup": scheduler_config}
scheduler = scheduler_class.from_config(scheduler_config)

# hack copied over from generate.py
if not hasattr(scheduler, 'uses_inpainting_model'):
scheduler.uses_inpainting_model = lambda: False
Expand Down Expand Up @@ -169,6 +176,7 @@ class TextToLatentsInvocation(BaseInvocation):
cfg_scale: float = Field(default=7.5, gt=0, description="The Classifier-Free Guidance, higher values may result in a result closer to the prompt", )
scheduler: SAMPLER_NAME_VALUES = Field(default="euler", description="The scheduler to use" )
model: str = Field(default="", description="The model to use (currently ignored)")
control: Union[ControlField, list[ControlField]] = Field(default=None, description="The control to use")
# seamless: bool = Field(default=False, description="Whether or not to generate an image that can tile without seams", )
# seamless_axes: str = Field(default="", description="The axes to tile the image on, 'x' and/or 'y'")
# fmt: on
Expand All @@ -179,7 +187,8 @@ class Config(InvocationConfig):
"ui": {
"tags": ["latents", "image"],
"type_hints": {
"model": "model"
"model": "model",
"control": "control",
}
},
}
Expand Down Expand Up @@ -238,6 +247,81 @@ def get_conditioning_data(self, context: InvocationContext, model: StableDiffusi
).add_scheduler_args_if_applicable(model.scheduler, eta=0.0)#ddim_eta)
return conditioning_data

def prep_control_data(self,
context: InvocationContext,
model: StableDiffusionGeneratorPipeline, # really only need model for dtype and device
control_input: List[ControlField],
latents_shape: List[int],
do_classifier_free_guidance: bool = True,
) -> List[ControlNetData]:
# assuming fixed dimensional scaling of 8:1 for image:latents
control_height_resize = latents_shape[2] * 8
control_width_resize = latents_shape[3] * 8
if control_input is None:
# print("control input is None")
control_list = None
elif isinstance(control_input, list) and len(control_input) == 0:
# print("control input is empty list")
control_list = None
elif isinstance(control_input, ControlField):
# print("control input is ControlField")
control_list = [control_input]
elif isinstance(control_input, list) and len(control_input) > 0 and isinstance(control_input[0], ControlField):
# print("control input is list[ControlField]")
control_list = control_input
else:
# print("input control is unrecognized:", type(self.control))
control_list = None
if (control_list is None):
control_data = None
# from above handling, any control that is not None should now be of type list[ControlField]
else:
# FIXME: add checks to skip entry if model or image is None
# and if weight is None, populate with default 1.0?
control_data = []
control_models = []
for control_info in control_list:
# handle control models
if ("," in control_info.control_model):
control_model_split = control_info.control_model.split(",")
control_name = control_model_split[0]
control_subfolder = control_model_split[1]
print("Using HF model subfolders")
print(" control_name: ", control_name)
print(" control_subfolder: ", control_subfolder)
control_model = ControlNetModel.from_pretrained(control_name,
subfolder=control_subfolder,
torch_dtype=model.unet.dtype).to(model.device)
else:
control_model = ControlNetModel.from_pretrained(control_info.control_model,
torch_dtype=model.unet.dtype).to(model.device)
control_models.append(control_model)
control_image_field = control_info.image
input_image = context.services.images.get_pil_image(control_image_field.image_type,
control_image_field.image_name)
# self.image.image_type, self.image.image_name
# FIXME: still need to test with different widths, heights, devices, dtypes
# and add in batch_size, num_images_per_prompt?
# and do real check for classifier_free_guidance?
# prepare_control_image should return torch.Tensor of shape(batch_size, 3, height, width)
control_image = model.prepare_control_image(
image=input_image,
do_classifier_free_guidance=do_classifier_free_guidance,
width=control_width_resize,
height=control_height_resize,
# batch_size=batch_size * num_images_per_prompt,
# num_images_per_prompt=num_images_per_prompt,
device=control_model.device,
dtype=control_model.dtype,
)
control_item = ControlNetData(model=control_model,
image_tensor=control_image,
weight=control_info.control_weight,
begin_step_percent=control_info.begin_step_percent,
end_step_percent=control_info.end_step_percent)
control_data.append(control_item)
# MultiControlNetModel has been refactored out, just need list[ControlNetData]
return control_data

def invoke(self, context: InvocationContext) -> LatentsOutput:
noise = context.services.latents.get(self.noise.latents_name)
Expand All @@ -252,14 +336,19 @@ def step_callback(state: PipelineIntermediateState):
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model)

# TODO: Verify the noise is the right size
print("type of control input: ", type(self.control))
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
do_classifier_free_guidance=(self.cfg_scale >= 1.0))

# TODO: Verify the noise is the right size
result_latents, result_attention_map_saver = model.latents_from_embeddings(
latents=torch.zeros_like(noise, dtype=torch_dtype(model.device)),
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
callback=step_callback
control_data=control_data, # list[ControlNetData]
callback=step_callback,
)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
Expand All @@ -285,7 +374,8 @@ class Config(InvocationConfig):
"ui": {
"tags": ["latents"],
"type_hints": {
"model": "model"
"model": "model",
"control": "control",
}
},
}
Expand All @@ -304,6 +394,11 @@ def step_callback(state: PipelineIntermediateState):
model = self.get_model(context.services.model_manager)
conditioning_data = self.get_conditioning_data(context, model)

print("type of control input: ", type(self.control))
control_data = self.prep_control_data(model=model, context=context, control_input=self.control,
latents_shape=noise.shape,
do_classifier_free_guidance=(self.cfg_scale >= 1.0))

# TODO: Verify the noise is the right size

initial_latents = latent if self.strength < 1.0 else torch.zeros_like(
Expand All @@ -318,6 +413,7 @@ def step_callback(state: PipelineIntermediateState):
noise=noise,
num_inference_steps=self.steps,
conditioning_data=conditioning_data,
control_data=control_data, # list[ControlNetData]
callback=step_callback
)

Expand Down Expand Up @@ -362,8 +458,14 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
np_image = model.decode_latents(latents)
image = model.numpy_to_pil(np_image)[0]

# what happened to metadata?
# metadata = context.services.metadata.build_metadata(
# session_id=context.graph_execution_state_id, node=self

torch.cuda.empty_cache()

# new (post Image service refactor) way of using services to save image
# and gnenerate unique image_name
image_dto = context.services.images.create(
image=image,
image_type=ImageType.RESULT,
Expand Down Expand Up @@ -414,6 +516,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
torch.cuda.empty_cache()

name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)

Expand Down Expand Up @@ -444,6 +547,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
torch.cuda.empty_cache()

name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, resized_latents)
context.services.latents.save(name, resized_latents)
return build_latents_output(latents_name=name, latents=resized_latents)

Expand All @@ -468,6 +572,9 @@ class Config(InvocationConfig):

@torch.no_grad()
def invoke(self, context: InvocationContext) -> LatentsOutput:
# image = context.services.images.get(
# self.image.image_type, self.image.image_name
# )
image = context.services.images.get_pil_image(
self.image.image_type, self.image.image_name
)
Expand All @@ -488,6 +595,6 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
)

name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, latents)
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)

9 changes: 9 additions & 0 deletions invokeai/app/invocations/math.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ class IntOutput(BaseInvocationOutput):
# fmt: on


class FloatOutput(BaseInvocationOutput):
"""A float output"""

# fmt: off
type: Literal["float_output"] = "float_output"
param: float = Field(default=None, description="The output float")
# fmt: on


class AddInvocation(BaseInvocation, MathInvocationConfig):
"""Adds two numbers"""

Expand Down
12 changes: 11 additions & 1 deletion invokeai/app/invocations/params.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Literal
from pydantic import Field
from .baseinvocation import BaseInvocation, BaseInvocationOutput, InvocationContext
from .math import IntOutput
from .math import IntOutput, FloatOutput

# Pass-through parameter nodes - used by subgraphs

Expand All @@ -16,3 +16,13 @@ class ParamIntInvocation(BaseInvocation):

def invoke(self, context: InvocationContext) -> IntOutput:
return IntOutput(a=self.a)

class ParamFloatInvocation(BaseInvocation):
"""A float parameter"""
#fmt: off
type: Literal["param_float"] = "param_float"
param: float = Field(default=0.0, description="The float value")
#fmt: on

def invoke(self, context: InvocationContext) -> FloatOutput:
return FloatOutput(param=self.param)
Loading