Commit 0453d60

1 parent: c4f4f8b

File tree

1 file changed: +14, -23 lines

invokeai/app/invocations/latent.py

Lines changed: 14 additions & 23 deletions
@@ -379,26 +379,24 @@ class ResizeLatentsInvocation(BaseInvocation):
     type: Literal["lresize"] = "lresize"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to resize")
-    width: int = Field(ge=64, multiple_of=8, description="The width to resize to")
-    height: int = Field(ge=64, multiple_of=8, description="The height to resize to")
-    downsample: int = Field(
-        default=8, ge=1, description="The downsampling factor (leave at 8 for SD)"
-    )
-    mode: LATENTS_INTERPOLATION_MODE = Field(
-        default="bilinear", description="The interpolation mode"
-    )
+    latents: Optional[LatentsField] = Field(description="The latents to resize")
+    width: int = Field(ge=64, multiple_of=8, description="The width to resize to (px)")
+    height: int = Field(ge=64, multiple_of=8, description="The height to resize to (px)")
+    downsample: Optional[int] = Field(default=8, ge=1, description="The downsampling factor (leave at 8 for SD)")
+    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
+    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
-        # resizing
+
         resized_latents = torch.nn.functional.interpolate(
             latents,
             size=(
                 self.height // self.downsample,
                 self.width // self.downsample,
             ),
             mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
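
For context, a minimal sketch (not part of the commit) of why the new antialias flag is gated on the interpolation mode: torch.nn.functional.interpolate only supports antialias=True in its "bilinear" and "bicubic" modes, so the node falls back to False for the others. The latents tensor below is an assumed stand-in shaped like an SD latent batch.

# Sketch only: demonstrates the antialias guard used in ResizeLatentsInvocation.
import torch

latents = torch.randn(1, 4, 64, 64)  # assumed (batch, channels, height, width) SD-style latents

for mode in ["nearest", "bilinear", "bicubic"]:
    resized = torch.nn.functional.interpolate(
        latents,
        size=(32, 32),  # target (height, width) in latent space
        mode=mode,
        # Mirror the node's guard: request antialiasing only where torch supports it;
        # antialias=True with "nearest" would raise an error.
        antialias=True if mode in ["bilinear", "bicubic"] else False,
    )
    print(mode, tuple(resized.shape))  # each prints (1, 4, 32, 32)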
@@ -415,27 +413,20 @@ class ScaleLatentsInvocation(BaseInvocation):
     type: Literal["lscale"] = "lscale"
 
     # Inputs
-    latents: Optional[LatentsField] = Field(description="The latents to resize")
-    scale: int = Field(
-        default=2, ge=1, description="The factor by which to scale the latents"
-    )
-    mode: LATENTS_INTERPOLATION_MODE = Field(
-        default="bilinear", description="The interpolation mode"
-    )
+    latents: Optional[LatentsField] = Field(description="The latents to scale")
+    scale_factor: float = Field(ge=0, description="The factor by which to scale the latents")
+    mode: Optional[LATENTS_INTERPOLATION_MODE] = Field(default="bilinear", description="The interpolation mode")
+    antialias: Optional[bool] = Field(default=False, description="Whether or not to antialias (applied in bilinear and bicubic modes only)")
 
     def invoke(self, context: InvocationContext) -> LatentsOutput:
         latents = context.services.latents.get(self.latents.latents_name)
 
-        (_, _, h, w) = latents.size()
-
         # resizing
         resized_latents = torch.nn.functional.interpolate(
             latents,
-            size=(
-                h * self.scale,
-                w * self.scale,
-            ),
+            scale_factor=self.scale_factor,
             mode=self.mode,
+            antialias=self.antialias if self.mode in ["bilinear", "bicubic"] else False,
         )
 
         # https://discuss.huggingface.co/t/memory-usage-by-later-pipeline-stages/23699
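
Similarly, a minimal sketch (again not part of the commit, with illustrative values) of what the scale_factor change buys: the removed code built size=(h * self.scale, w * self.scale) from an integer scale constrained to ge=1, so it could only upscale by whole multiples, while a float scale_factor passed straight to interpolate also permits fractional scaling and downscaling.

# Sketch only: a float scale_factor supports fractional scaling and downscaling.
import torch

latents = torch.randn(1, 4, 64, 64)  # assumed SD-style latent batch

for scale_factor in [0.5, 1.5, 2.0]:
    scaled = torch.nn.functional.interpolate(
        latents,
        scale_factor=scale_factor,  # 0.5 and 1.5 were impossible with the old int field
        mode="bilinear",
        antialias=scale_factor < 1.0,  # antialiasing mainly matters when shrinking
    )
    print(scale_factor, tuple(scaled.shape))
# prints (1, 4, 32, 32), (1, 4, 96, 96), (1, 4, 128, 128)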
