invokeai/app/invocations/compel.py (10 additions, 0 deletions)
@@ -148,6 +148,8 @@ def _lora_loader():
cross_attention_control_args=options.get(
"cross_attention_control", None),)

c = c.detach().to("cpu")

conditioning_data = ConditioningFieldData(
conditionings=[
BasicConditioningInfo(
@@ -230,6 +232,10 @@ def _lora_loader():
del tokenizer_info
del text_encoder_info

c = c.detach().to("cpu")
if c_pooled is not None:
c_pooled = c_pooled.detach().to("cpu")

return c, c_pooled, None

def run_clip_compel(self, context, clip_field, prompt, get_pooled):
@@ -306,6 +312,10 @@ def _lora_loader():
del tokenizer_info
del text_encoder_info

c = c.detach().to("cpu")
if c_pooled is not None:
c_pooled = c_pooled.detach().to("cpu")

return c, c_pooled, ec

class SDXLCompelPromptInvocation(BaseInvocation, SDXLPromptInvocationBase):
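The compel.py additions above all follow one pattern: prompt embeddings coming out of the text encoder are detached from the autograd graph and moved to CPU before the conditioning is cached, so cached conditioning no longer holds VRAM between invocations. A minimal sketch of that pattern, with offload_conditioning as a hypothetical helper name (the diff only shows the detach-and-offload lines inside the existing invocations):

    from typing import Optional, Tuple
    import torch

    def offload_conditioning(
        c: torch.Tensor, c_pooled: Optional[torch.Tensor] = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        # Drop autograd history and park the tensors on CPU; the consumer
        # (see generate.py and latent.py below) moves them back to the UNet's
        # device and dtype right before sampling.
        c = c.detach().to("cpu")
        if c_pooled is not None:
            c_pooled = c_pooled.detach().to("cpu")
        return c, c_pooled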
invokeai/app/invocations/generate.py (5 additions, 4 deletions)
@@ -146,13 +146,13 @@ def dispatch_progress(
source_node_id=source_node_id,
)

def get_conditioning(self, context):
def get_conditioning(self, context, unet):
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].embeds
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning

negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].embeds
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)

return (uc, c, extra_conditioning_info)

@@ -213,14 +213,15 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
)
source_node_id = graph_execution_state.prepared_source_mapping[self.id]

conditioning = self.get_conditioning(context)
scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)

with self.load_model_old_way(context, scheduler) as model:
conditioning = self.get_conditioning(context, model.context.model.unet)

outputs = Inpaint(model).generate(
conditioning=conditioning,
scheduler=scheduler,
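On the consumer side, get_conditioning now takes the UNet as an argument so the cached CPU embeddings can be cast back to the model's device and dtype at the point of use. A sketch of that lookup, assuming objects shaped like the ones in the diff (a latents service exposing get(), conditioning data exposing embeds and extra_conditioning, and a diffusers-style UNet exposing device and dtype):

    def load_conditioning(latents_service, conditioning_name, unet):
        # Cached embeddings live on CPU; move them to wherever the UNet runs.
        data = latents_service.get(conditioning_name)
        cond = data.conditionings[0]
        embeds = cond.embeds.to(device=unet.device, dtype=unet.dtype)
        return embeds, cond.extra_conditioning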
invokeai/app/invocations/latent.py (25 additions, 8 deletions)
@@ -167,13 +167,14 @@ def get_conditioning_data(
self,
context: InvocationContext,
scheduler,
unet,
) -> ConditioningData:
positive_cond_data = context.services.latents.get(self.positive_conditioning.conditioning_name)
c = positive_cond_data.conditionings[0].embeds
c = positive_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)
extra_conditioning_info = positive_cond_data.conditionings[0].extra_conditioning

negative_cond_data = context.services.latents.get(self.negative_conditioning.conditioning_name)
uc = negative_cond_data.conditionings[0].embeds
uc = negative_cond_data.conditionings[0].embeds.to(device=unet.device, dtype=unet.dtype)

conditioning_data = ConditioningData(
unconditioned_embeddings=uc,
@@ -195,7 +196,7 @@ def get_conditioning_data(
eta=0.0, # ddim_eta

# for ancestral and sde schedulers
generator=torch.Generator(device=uc.device).manual_seed(0),
generator=torch.Generator(device=unet.device).manual_seed(0),
)
return conditioning_data

@@ -334,14 +335,16 @@ def _lora_loader():
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:

noise = noise.to(device=unet.device, dtype=unet.dtype)

scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)

pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)

control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
@@ -362,6 +365,7 @@
)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()

name = f'{context.graph_execution_state_id}__{self.id}'
@@ -424,14 +428,17 @@ def _lora_loader():
ModelPatcher.apply_lora_unet(unet_info.context.model, _lora_loader()),\
unet_info as unet:

noise = noise.to(device=unet.device, dtype=unet.dtype)
latent = latent.to(device=unet.device, dtype=unet.dtype)

scheduler = get_scheduler(
context=context,
scheduler_info=self.unet.scheduler,
scheduler_name=self.scheduler,
)

pipeline = self.create_pipeline(unet, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler)
conditioning_data = self.get_conditioning_data(context, scheduler, unet)

control_data = self.prep_control_data(
model=pipeline, context=context, control_input=self.control,
@@ -463,6 +470,7 @@ def _lora_loader():
)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
result_latents = result_latents.to("cpu")
torch.cuda.empty_cache()

name = f'{context.graph_execution_state_id}__{self.id}'
Expand Down Expand Up @@ -503,6 +511,7 @@ def invoke(self, context: InvocationContext) -> ImageOutput:
)

with vae_info as vae:
latents = latents.to(vae.device)
if self.fp32:
vae.to(dtype=torch.float32)

@@ -590,13 +599,17 @@ class ResizeLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)

# TODO:
device=choose_torch_device()

resized_latents = torch.nn.functional.interpolate(
latents, size=(self.height // 8, self.width // 8),
latents.to(device), size=(self.height // 8, self.width // 8),
mode=self.mode, antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,
)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()

name = f"{context.graph_execution_state_id}__{self.id}"
@@ -624,14 +637,18 @@ class ScaleLatentsInvocation(BaseInvocation):
def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = context.services.latents.get(self.latents.latents_name)

# TODO:
device=choose_torch_device()

# resizing
resized_latents = torch.nn.functional.interpolate(
latents, scale_factor=self.scale_factor, mode=self.mode,
latents.to(device), scale_factor=self.scale_factor, mode=self.mode,
antialias=self.antialias
if self.mode in ["bilinear", "bicubic"] else False,
)

# https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
resized_latents = resized_latents.to("cpu")
torch.cuda.empty_cache()

name = f"{context.graph_execution_state_id}__{self.id}"
@@ -722,6 +739,6 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
latents = latents.to(dtype=orig_dtype)

name = f"{context.graph_execution_state_id}__{self.id}"
# context.services.latents.set(name, latents)
latents = latents.to("cpu")
context.services.latents.save(name, latents)
return build_latents_output(latents_name=name, latents=latents)
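ResizeLatentsInvocation and ScaleLatentsInvocation now run the interpolation on the device returned by choose_torch_device() and offload the result to CPU before saving; likewise, the VAE decode path moves the latents onto the VAE's device before decoding, and the image-to-latents path saves its latents on CPU. A sketch of the resize case, with a stand-in for the device helper since its import is not shown in this diff:

    import torch
    import torch.nn.functional as F

    def choose_torch_device() -> torch.device:
        # Stand-in for InvokeAI's device helper: prefer CUDA when available.
        return torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    def resize_latents(latents: torch.Tensor, width: int, height: int,
                       mode: str = "bilinear", antialias: bool = True) -> torch.Tensor:
        device = choose_torch_device()
        resized = F.interpolate(
            latents.to(device), size=(height // 8, width // 8), mode=mode,
            antialias=antialias if mode in ["bilinear", "bicubic"] else False,
        )
        # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
        resized = resized.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return resized

For example, resize_latents(torch.randn(1, 4, 64, 64), width=768, height=768) upsamples a 512x512 latent to the 768x768 latent size on the GPU when one is present and returns the result on CPU.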
invokeai/app/invocations/noise.py (1 addition, 1 deletion)
@@ -48,7 +48,7 @@ def get_noise(
dtype=torch_dtype(device),
device=noise_device_type,
generator=generator,
).to(device)
).to("cpu")

return noise_tensor

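get_noise still seeds and samples on the originally requested device, so the generated values come from the same RNG stream as before, but the tensor is now returned on CPU; the denoise invocations move it onto the UNet's device themselves (see latent.py above). A simplified sketch (the real function derives the shape and dtype from its other parameters, which are not shown in this hunk):

    import torch

    def get_noise(width: int, height: int, device: torch.device, seed: int = 0) -> torch.Tensor:
        # Seed and sample on the requested device so results match earlier behavior,
        # then hand the tensor back on CPU for caching.
        generator = torch.Generator(device=device).manual_seed(seed)
        noise = torch.randn(
            (1, 4, height // 8, width // 8),
            dtype=torch.float32,
            device=device,
            generator=generator,
        ).to("cpu")
        return noise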
invokeai/app/invocations/sdxl.py (4 additions, 2 deletions)
@@ -305,7 +305,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)

with tqdm(total=self.steps) as progress_bar:
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -351,7 +351,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:
add_time_ids = add_time_ids.to(device=unet.device, dtype=unet.dtype)
latents = latents.to(device=unet.device, dtype=unet.dtype)

with tqdm(total=self.steps) as progress_bar:
with tqdm(total=num_inference_steps) as progress_bar:
for i, t in enumerate(timesteps):
# expand the latents if we are doing classifier free guidance
#latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
@@ -415,6 +415,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

#################

latents = latents.to("cpu")
torch.cuda.empty_cache()

name = f'{context.graph_execution_state_id}__{self.id}'
@@ -651,6 +652,7 @@ def invoke(self, context: InvocationContext) -> LatentsOutput:

#################

latents = latents.to("cpu")
torch.cuda.empty_cache()

name = f'{context.graph_execution_state_id}__{self.id}'
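Two fixes land in sdxl.py: the denoise loops size their progress bars with num_inference_steps, a count computed earlier in the invocation from the scheduler's timesteps (not shown in these hunks), instead of the nominal self.steps, so the bar matches the loop it wraps; and the finished latents are moved to CPU before the CUDA cache is cleared and the result is saved. A sketch of the corrected loop shape, with step_fn standing in for the UNet and scheduler step:

    from typing import Callable
    import torch
    from tqdm import tqdm

    def denoise_loop(
        step_fn: Callable[[torch.Tensor, torch.Tensor], torch.Tensor],
        latents: torch.Tensor,
        timesteps: torch.Tensor,
    ) -> torch.Tensor:
        # The progress bar tracks the schedule actually iterated, not self.steps.
        num_inference_steps = len(timesteps)
        with tqdm(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latents = step_fn(latents, t)
                progress_bar.update()
        # Offload before saving so the cached latents do not hold VRAM.
        latents = latents.to("cpu")
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        return latents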
invokeai/backend/model_management/model_cache.py (2 additions, 1 deletion)
@@ -104,7 +104,8 @@ def __init__(
:param sha_chunksize: Chunksize to use when calculating sha256 model hash
'''
self.model_infos: Dict[str, ModelBase] = dict()
self.lazy_offloading = lazy_offloading
# allow lazy offloading only when vram cache enabled
self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
self.precision: torch.dtype=precision
self.max_cache_size: float=max_cache_size
self.max_vram_cache_size: float=max_vram_cache_size
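Finally, the model cache now enables lazy offloading only when a VRAM cache is actually configured (max_vram_cache_size > 0); with no VRAM cache the flag is forced off, so models do not linger on the GPU after they are released. A minimal sketch of the constructor logic, with illustrative defaults and the unrelated parameters omitted:

    import torch

    class ModelCache:
        def __init__(
            self,
            max_cache_size: float = 6.0,        # RAM cache size in GB, illustrative default
            max_vram_cache_size: float = 0.0,   # VRAM cache size in GB; 0 disables it
            lazy_offloading: bool = True,
            precision: torch.dtype = torch.float16,
        ):
            # Lazy offloading only makes sense when a VRAM cache is enabled.
            self.lazy_offloading = lazy_offloading and max_vram_cache_size > 0
            self.precision = precision
            self.max_cache_size = max_cache_size
            self.max_vram_cache_size = max_vram_cache_size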