@@ -44,7 +44,7 @@ def get_dummy_components(self):
44
44
torch .manual_seed (0 )
45
45
unet = UNet2DConditionModel (
46
46
block_out_channels = (32 , 64 ),
47
- layers_per_block = 2 ,
47
+ layers_per_block = 1 ,
48
48
sample_size = 32 ,
49
49
in_channels = 4 ,
50
50
out_channels = 4 ,
@@ -111,7 +111,7 @@ def get_dummy_inputs(self, device, seed=0):
111
111
"prompt" : "a cat and a frog" ,
112
112
"token_indices" : [2 , 5 ],
113
113
"generator" : generator ,
114
- "num_inference_steps" : 2 ,
114
+ "num_inference_steps" : 1 ,
115
115
"guidance_scale" : 6.0 ,
116
116
"output_type" : "numpy" ,
117
117
"max_iter_to_alter" : 2 ,
@@ -132,13 +132,18 @@ def test_inference(self):
132
132
image_slice = image [0 , - 3 :, - 3 :, - 1 ]
133
133
134
134
self .assertEqual (image .shape , (1 , 64 , 64 , 3 ))
135
- expected_slice = np .array ([0.5743 , 0.6081 , 0.4975 , 0.5021 , 0.5441 , 0.4699 , 0.4988 , 0.4841 , 0.4851 ])
135
+ expected_slice = np .array (
136
+ [0.63905364 , 0.62897307 , 0.48599017 , 0.5133624 , 0.5550048 , 0.45769516 , 0.50326973 , 0.5023139 , 0.45384496 ]
137
+ )
136
138
max_diff = np .abs (image_slice .flatten () - expected_slice ).max ()
137
139
self .assertLessEqual (max_diff , 1e-3 )
138
140
139
141
def test_inference_batch_consistent (self ):
140
142
# NOTE: Larger batch sizes cause this test to timeout, only test on smaller batches
141
- self ._test_inference_batch_consistent (batch_sizes = [2 , 4 ])
143
+ self ._test_inference_batch_consistent (batch_sizes = [1 , 2 ])
144
+
145
+ def test_inference_batch_single_identical (self ):
146
+ self ._test_inference_batch_single_identical (batch_size = 2 )
142
147
143
148
144
149
@require_torch_gpu
0 commit comments