Skip to content
2 changes: 1 addition & 1 deletion ControlNeXt-SVD-v2-Training/models/controlnext_vid_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from diffusers.models.resnet import Downsample2D, ResnetBlock2D


class ControlNeXtSDVModel(ModelMixin, ConfigMixin):
class ControlNeXtSVDModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True

@register_to_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from models.controlnext_vid_svd import ControlNeXtSDVModel
from models.controlnext_vid_svd import ControlNeXtSVDModel

from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(
vae: AutoencoderKLTemporalDecoder,
image_encoder: CLIPVisionModelWithProjection,
unet: UNetSpatioTemporalConditionControlNeXtModel,
controlnext: ControlNeXtSDVModel,
controlnext: ControlNeXtSVDModel,
scheduler: EulerDiscreteScheduler,
feature_extractor: CLIPImageProcessor,
):
Expand Down
4 changes: 2 additions & 2 deletions ControlNeXt-SVD-v2-Training/train_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,7 @@
from utils.vid_dataset import UBCFashion
from models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNeXtModel
from pipeline.pipeline_stable_video_diffusion_controlnext import StableVideoDiffusionPipelineControlNeXt
from models.controlnext_vid_svd import ControlNeXtSDVModel
from models.controlnext_vid_svd import ControlNeXtSVDModel
import torch.nn as nn
import pdb
from diffusers.utils.torch_utils import randn_tensor
Expand Down Expand Up @@ -920,7 +920,7 @@ def main():
)

logger.info("Initializing controlnext weights from unet")
controlnext = ControlNeXtSDVModel()
controlnext = ControlNeXtSVDModel()

if args.controlnet_model_name_or_path:
logger.info("Loading existing controlnet weights")
Expand Down
2 changes: 1 addition & 1 deletion ControlNeXt-SVD-v2/models/controlnext_vid_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from diffusers.models.resnet import Downsample2D, ResnetBlock2D


class ControlNeXtSDVModel(ModelMixin, ConfigMixin):
class ControlNeXtSVDModel(ModelMixin, ConfigMixin):
_supports_gradient_checkpointing = True

@register_to_config
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from models.controlnext_vid_svd import ControlNeXtSDVModel
from models.controlnext_vid_svd import ControlNeXtSVDModel

from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(
vae: AutoencoderKLTemporalDecoder,
image_encoder: CLIPVisionModelWithProjection,
unet: UNetSpatioTemporalConditionControlNeXtModel,
controlnext: ControlNeXtSDVModel,
controlnext: ControlNeXtSVDModel,
scheduler: EulerDiscreteScheduler,
feature_extractor: CLIPImageProcessor,
):
Expand Down
6 changes: 3 additions & 3 deletions ControlNeXt-SVD-v2/run_controlnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from PIL import Image
from pipeline.pipeline_stable_video_diffusion_controlnext import StableVideoDiffusionPipelineControlNeXt
from models.controlnext_vid_svd import ControlNeXtSDVModel
from models.controlnext_vid_svd import ControlNeXtSVDModel
from models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNeXtModel
from transformers import CLIPVisionModelWithProjection
import re
Expand Down Expand Up @@ -221,7 +221,7 @@ def load_tensor(tensor_path):
subfolder="unet",
low_cpu_mem_usage=True,
)
controlnext = ControlNeXtSDVModel()
controlnext = ControlNeXtSVDModel()
controlnext.load_state_dict(load_tensor(args.controlnext_path))
unet.load_state_dict(load_tensor(args.unet_path), strict=False)

Expand Down Expand Up @@ -279,4 +279,4 @@ def load_tensor(tensor_path):
final_result,
validation_control_images[:num_frames],
args.output_dir,
fps=fps)
fps=fps)
2 changes: 1 addition & 1 deletion ControlNeXt-SVD/models/controlnext_vid_svd.py
Original file line number Diff line number Diff line change
Expand Up @@ -337,7 +337,7 @@ def forward(



class ControlNeXtSDVModel(ModelMixin, ConfigMixin):
class ControlNeXtSVDModel(ModelMixin, ConfigMixin):

_supports_gradient_checkpointing = True

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import PIL.Image
import torch
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
from models.controlnext_vid_svd import ControlNeXtSDVModel
from models.controlnext_vid_svd import ControlNeXtSVDModel

from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
Expand Down Expand Up @@ -125,7 +125,7 @@ def __init__(
vae: AutoencoderKLTemporalDecoder,
image_encoder: CLIPVisionModelWithProjection,
unet: UNetSpatioTemporalConditionControlNeXtModel,
controlnext: ControlNeXtSDVModel,
controlnext: ControlNeXtSVDModel,
scheduler: EulerDiscreteScheduler,
feature_extractor: CLIPImageProcessor,
):
Expand Down
4 changes: 2 additions & 2 deletions ControlNeXt-SVD/run_controlnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import numpy as np
from PIL import Image
from pipeline.pipeline_stable_video_diffusion_controlnext import StableVideoDiffusionPipelineControlNeXt
from models.controlnext_vid_svd import ControlNeXtSDVModel
from models.controlnext_vid_svd import ControlNeXtSVDModel
from models.unet_spatio_temporal_condition_controlnext import UNetSpatioTemporalConditionControlNeXtModel
from transformers import CLIPVisionModelWithProjection
import re
Expand Down Expand Up @@ -202,7 +202,7 @@ def load_tensor(tensor_path):
low_cpu_mem_usage=True,
variant="fp16",
)
controlnext = ControlNeXtSDVModel()
controlnext = ControlNeXtSVDModel()
controlnext.load_state_dict(load_tensor(args.controlnext_path))
unet.load_state_dict(load_tensor(args.unet_path), strict=False)

Expand Down