diff --git a/scripts/generate_logits.py b/scripts/generate_logits.py index 89dce0e78d4e..99d46d6628a6 100644 --- a/scripts/generate_logits.py +++ b/scripts/generate_logits.py @@ -103,12 +103,12 @@ models = api.list_models(filter="diffusers") for mod in models: - if "google" in mod.author or mod.modelId == "CompVis/ldm-celebahq-256": - local_checkpoint = "/home/patrick/google_checkpoints/" + mod.modelId.split("/")[-1] + if "google" in mod.author or mod.id == "CompVis/ldm-celebahq-256": + local_checkpoint = "/home/patrick/google_checkpoints/" + mod.id.split("/")[-1] - print(f"Started running {mod.modelId}!!!") + print(f"Started running {mod.id}!!!") - if mod.modelId.startswith("CompVis"): + if mod.id.startswith("CompVis"): model = UNet2DModel.from_pretrained(local_checkpoint, subfolder="unet") else: model = UNet2DModel.from_pretrained(local_checkpoint) @@ -122,6 +122,6 @@ logits = model(noise, time_step).sample assert torch.allclose( - logits[0, 0, 0, :30], results["_".join("_".join(mod.modelId.split("/")).split("-"))], atol=1e-3 + logits[0, 0, 0, :30], results["_".join("_".join(mod.id.split("/")).split("-"))], atol=1e-3 ) - print(f"{mod.modelId} has passed successfully!!!") + print(f"{mod.id} has passed successfully!!!")