Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 14 additions & 3 deletions scripts/convert_sana_video_to_diffusers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
AutoencoderKLWan,
DPMSolverMultistepScheduler,
FlowMatchEulerDiscreteScheduler,
SanaVideoCausalTransformer3DModel,
SanaVideoPipeline,
SanaVideoTransformer3DModel,
UniPCMultistepScheduler,
Expand All @@ -24,7 +25,10 @@

CTX = init_empty_weights if is_accelerate_available else nullcontext

ckpt_ids = ["Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth"]
ckpt_ids = [
"Efficient-Large-Model/SANA-Video_2B_480p/checkpoints/SANA_Video_2B_480p.pth",
"Efficient-Large-Model/Sana-Video_2B_480p_LongLive/checkpoints/SANA_Video_2B_480p_LongLive.pth",
]
# https://github.com/NVlabs/Sana/blob/main/inference_video_scripts/inference_sana_video.py


Expand Down Expand Up @@ -98,6 +102,10 @@ def main(args):
else:
raise ValueError(f"Video size {args.video_size} is not supported.")

use_causal_linear_attn = False
if "Sana-Video_2B_480p_LongLive" in file_path:
use_causal_linear_attn = True

for depth in range(layer_num):
# Transformer blocks.
converted_state_dict[f"transformer_blocks.{depth}.scale_shift_table"] = state_dict.pop(
Expand Down Expand Up @@ -201,7 +209,10 @@ def main(args):
"rope_max_seq_len": 1024,
}

transformer = SanaVideoTransformer3DModel(**transformer_kwargs)
if use_causal_linear_attn:
transformer = SanaVideoCausalTransformer3DModel(**transformer_kwargs)
else:
transformer = SanaVideoTransformer3DModel(**transformer_kwargs)

transformer.load_state_dict(converted_state_dict, strict=True, assign=True)

Expand Down Expand Up @@ -314,7 +325,7 @@ def main(args):
choices=["flow-dpm_solver", "flow-euler", "uni-pc"],
help="Scheduler type to use.",
)
parser.add_argument("--task", default="t2v", type=str, required=True, help="Task to convert, t2v or i2v.")
parser.add_argument("--task", default="t2v", type=str, required=True, choices=["t2v", "i2v"], help="Task to convert, t2v or i2v.")
parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output pipeline.")
parser.add_argument("--save_full_pipeline", action="store_true", help="save all the pipeline elements in one.")
parser.add_argument("--dtype", default="fp32", type=str, choices=["fp32", "fp16", "bf16"], help="Weight dtype.")
Expand Down
5 changes: 4 additions & 1 deletion src/diffusers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@
"SanaControlNetModel",
"SanaTransformer2DModel",
"SanaVideoTransformer3DModel",
"SanaVideoCausalTransformer3DModel",
"SD3ControlNetModel",
"SD3MultiControlNetModel",
"SD3Transformer2DModel",
Expand Down Expand Up @@ -555,7 +556,7 @@
"SanaSprintImg2ImgPipeline",
"SanaSprintPipeline",
"SanaVideoPipeline",
"SanaVideoPipeline",
"LongSanaVideoPipeline",
"SemanticStableDiffusionPipeline",
"ShapEImg2ImgPipeline",
"ShapEPipeline",
Expand Down Expand Up @@ -968,6 +969,7 @@
QwenImageTransformer2DModel,
SanaControlNetModel,
SanaTransformer2DModel,
SanaVideoCausalTransformer3DModel,
SanaVideoTransformer3DModel,
SD3ControlNetModel,
SD3MultiControlNetModel,
Expand Down Expand Up @@ -1206,6 +1208,7 @@
LDMTextToImagePipeline,
LEditsPPPipelineStableDiffusion,
LEditsPPPipelineStableDiffusionXL,
LongSanaVideoPipeline,
LTXConditionPipeline,
LTXImageToVideoPipeline,
LTXLatentUpsamplePipeline,
Expand Down
2 changes: 2 additions & 0 deletions src/diffusers/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@
_import_structure["transformers.transformer_prx"] = ["PRXTransformer2DModel"]
_import_structure["transformers.transformer_qwenimage"] = ["QwenImageTransformer2DModel"]
_import_structure["transformers.transformer_sana_video"] = ["SanaVideoTransformer3DModel"]
_import_structure["transformers.transformer_sana_video_causal"] = ["SanaVideoCausalTransformer3DModel"]
_import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
_import_structure["transformers.transformer_skyreels_v2"] = ["SkyReelsV2Transformer3DModel"]
_import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
Expand Down Expand Up @@ -213,6 +214,7 @@
PRXTransformer2DModel,
QwenImageTransformer2DModel,
SanaTransformer2DModel,
SanaVideoCausalTransformer3DModel,
SanaVideoTransformer3DModel,
SD3Transformer2DModel,
SkyReelsV2Transformer3DModel,
Expand Down
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@
from .transformer_prx import PRXTransformer2DModel
from .transformer_qwenimage import QwenImageTransformer2DModel
from .transformer_sana_video import SanaVideoTransformer3DModel
from .transformer_sana_video_causal import SanaVideoCausalTransformer3DModel
from .transformer_sd3 import SD3Transformer2DModel
from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
from .transformer_temporal import TransformerTemporalModel
Expand Down
Loading