Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 81 additions & 1 deletion tests/test_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import numpy as np
import torch

from diffusers import DDIMScheduler, DDPMScheduler, PNDMScheduler, ScoreSdeVeScheduler
from diffusers import DDIMScheduler, DDPMScheduler, LMSDiscreteScheduler, PNDMScheduler, ScoreSdeVeScheduler


torch.backends.cuda.matmul.allow_tf32 = False
Expand Down Expand Up @@ -853,3 +853,83 @@ def test_step_shape(self):

self.assertEqual(output_0.shape, sample.shape)
self.assertEqual(output_0.shape, output_1.shape)


class LMSDiscreteSchedulerTest(SchedulerCommonTest):
scheduler_classes = (LMSDiscreteScheduler,)
num_inference_steps = 10

def get_scheduler_config(self, **kwargs):
config = {
"num_train_timesteps": 1100,
"beta_start": 0.0001,
"beta_end": 0.02,
"beta_schedule": "linear",
"trained_betas": None,
"tensor_format": "pt",
}

config.update(**kwargs)
return config

def test_timesteps(self):
for timesteps in [10, 50, 100, 1000]:
self.check_over_configs(num_train_timesteps=timesteps)

def test_betas(self):
for beta_start, beta_end in zip([0.0001, 0.001, 0.01, 0.1], [0.002, 0.02, 0.2, 2]):
self.check_over_configs(beta_start=beta_start, beta_end=beta_end)

def test_schedules(self):
for schedule in ["linear", "scaled_linear"]:
self.check_over_configs(beta_schedule=schedule)

def test_time_indices(self):
for t in [0, 500, 800]:
self.check_over_forward(time_step=t)

def test_pytorch_equal_numpy(self):
for scheduler_class in self.scheduler_classes:
sample_pt = self.dummy_sample
residual_pt = 0.1 * sample_pt

sample = sample_pt.numpy()
residual = 0.1 * sample

scheduler_config = self.get_scheduler_config()
scheduler_config["tensor_format"] = "np"
scheduler = scheduler_class(**scheduler_config)

scheduler_config["tensor_format"] = "pt"
scheduler_pt = scheduler_class(**scheduler_config)

scheduler.set_timesteps(self.num_inference_steps)
scheduler_pt.set_timesteps(self.num_inference_steps)

output = scheduler.step(residual, 1, sample).prev_sample
output_pt = scheduler_pt.step(residual_pt, 1, sample_pt).prev_sample
assert np.sum(np.abs(output - output_pt.numpy())) < 1e-4, "Scheduler outputs are not identical"

def test_full_loop_no_noise(self):
scheduler_class = self.scheduler_classes[0]
scheduler_config = self.get_scheduler_config()
scheduler = scheduler_class(**scheduler_config)

scheduler.set_timesteps(self.num_inference_steps)

model = self.dummy_model()
sample = self.dummy_sample_deter * scheduler.sigmas[0]

for i, t in enumerate(scheduler.timesteps):
sample = sample / ((scheduler.sigmas[i] ** 2 + 1) ** 0.5)

model_output = model(sample, t)

output = scheduler.step(model_output, i, sample)
sample = output.prev_sample

result_sum = torch.sum(torch.abs(sample))
result_mean = torch.mean(torch.abs(sample))

assert abs(result_sum.item() - 1006.388) < 1e-2
assert abs(result_mean.item() - 1.31) < 1e-3