[Tests] better determinism #3374
Changes from all commits
First file:

```diff
@@ -27,6 +27,7 @@
 logger = logging.get_logger(__name__)
 torch.backends.cuda.matmul.allow_tf32 = False
+torch.use_deterministic_algorithms(True)


 class Unet2DModelTests(ModelTesterMixin, unittest.TestCase):
```
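For context, a minimal sketch of what these two module-level flags buy the test suite. Only the two flag names come from the diff above; everything else is illustrative. Note also that with CUDA, `torch.use_deterministic_algorithms(True)` may additionally require the `CUBLAS_WORKSPACE_CONFIG=:4096:8` environment variable:

```python
import torch

# Force full-precision float32 matmuls; on Ampere+ GPUs the default TF32
# kernels round the mantissa and can shift results past a test tolerance.
torch.backends.cuda.matmul.allow_tf32 = False

# Prefer deterministic kernel implementations; ops without a deterministic
# variant raise a RuntimeError instead of varying from run to run.
torch.use_deterministic_algorithms(True)

# Same seed + deterministic kernels => identical results across runs.
gen = torch.Generator().manual_seed(0)
x = torch.randn(8, 8, generator=gen)
y = x @ x

gen = torch.Generator().manual_seed(0)
x2 = torch.randn(8, 8, generator=gen)
assert torch.equal(y, x2 @ x2)
```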
```diff
@@ -246,10 +247,6 @@ def test_output_pretrained_ve_mid(self):
         model = UNet2DModel.from_pretrained("google/ncsnpp-celebahq-256")
         model.to(torch_device)

-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
-
         batch_size = 4
         num_channels = 3
         sizes = (256, 256)
```

Comment on lines -249 to -252: Not needed with …
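Background on the removal, for readers skimming the diff: `torch.manual_seed` reseeds the process-global CPU RNG, and `torch.cuda.manual_seed_all` does the same for every CUDA device, so dropping them is safe whenever a test's inputs don't consume the global RNG. A sketch of the generator-based alternative that keeps seeding local (the `make_noise` helper is hypothetical, not from this repo):

```python
import torch

# A local torch.Generator scopes seeding to the tensors that need it,
# so tests cannot leak RNG state into one another via the global seed.
def make_noise(shape, seed=0, device="cpu"):  # hypothetical helper
    gen = torch.Generator(device=device).manual_seed(seed)
    return torch.randn(shape, generator=gen, device=device)

noise = make_noise((4, 3, 256, 256))
assert torch.equal(noise, make_noise((4, 3, 256, 256)))
```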
```diff
@@ -262,7 +259,7 @@ def test_output_pretrained_ve_mid(self):
         output_slice = output[0, -3:, -3:, -1].flatten().cpu()
         # fmt: off
-        expected_output_slice = torch.tensor([-4836.2231, -6487.1387, -3816.7969, -7964.9253, -10966.2842, -20043.6016, 8137.0571, 2340.3499, 544.6114])
+        expected_output_slice = torch.tensor([-4842.8691, -6499.6631, -3800.1953, -7978.2686, -10980.7129, -20028.8535, 8148.2822, 2342.2905, 567.7608])
         # fmt: on

         self.assertTrue(torch_all_close(output_slice, expected_output_slice, rtol=1e-2))
```

Comment: With PyTorch 2.0, this had to be changed.
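`torch_all_close` itself is not shown in this diff; assuming it forwards to `torch.allclose`, the relative-tolerance check behaves roughly as below. With output values in the thousands, `rtol=1e-2` permits drift of up to ~1% of each element's magnitude:

```python
import torch

def torch_all_close(a, b, rtol=1e-5, atol=1e-8):
    # Assumed semantics (torch.allclose): |a - b| <= atol + rtol * |b|, elementwise.
    return torch.allclose(a, b, rtol=rtol, atol=atol)

expected = torch.tensor([-4842.8691, -6499.6631, -3800.1953])
observed = expected + 20.0  # ~0.3-0.5% of each element's magnitude

assert torch_all_close(observed, expected, rtol=1e-2)      # within 1%: passes
assert not torch_all_close(observed, expected, rtol=1e-3)  # within 0.1%: fails
```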
```diff
@@ -271,10 +268,6 @@ def test_output_pretrained_ve_large(self):
         model = UNet2DModel.from_pretrained("fusing/ncsnpp-ffhq-ve-dummy-update")
         model.to(torch_device)

-        torch.manual_seed(0)
-        if torch.cuda.is_available():
-            torch.cuda.manual_seed_all(0)
-
         batch_size = 4
         num_channels = 3
         sizes = (32, 32)
```

Comment: nice clean up!
Second file:

```diff
@@ -39,6 +39,7 @@
 logger = logging.get_logger(__name__)
 torch.backends.cuda.matmul.allow_tf32 = False
+torch.use_deterministic_algorithms(True)


 def create_lora_layers(model, mock_weights: bool = True):
```
```diff
@@ -442,8 +443,8 @@ def test_lora_processors(self):
         sample3 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample
         sample4 = model(**inputs_dict, cross_attention_kwargs={"scale": 0.5}).sample

-        assert (sample1 - sample2).abs().max() < 1e-4
-        assert (sample3 - sample4).abs().max() < 1e-4
+        assert (sample1 - sample2).abs().max() < 3e-3
+        assert (sample3 - sample4).abs().max() < 3e-3

         # sample 2 and sample 3 should be different
         assert (sample2 - sample3).abs().max() > 1e-4
```

Comment on lines +446 to +447: Explained in the PR description why I had to relax the tolerance.
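To make the relaxation concrete: the bound is on the largest elementwise absolute difference between two forward passes. A toy illustration (values invented) of what moving from 1e-4 to 3e-3 tolerates:

```python
import torch

sample1 = torch.tensor([0.1000, -0.2500, 0.7300])
sample2 = torch.tensor([0.1004, -0.2498, 0.7321])  # max drift ~2.1e-3

diff = (sample1 - sample2).abs().max()
assert diff < 3e-3        # passes under the relaxed bound
assert not (diff < 1e-4)  # would have failed under the old bound
```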
```diff
@@ -587,7 +588,7 @@ def test_lora_on_off(self):
         new_sample = model(**inputs_dict).sample

         assert (sample - new_sample).abs().max() < 1e-4
-        assert (sample - old_sample).abs().max() < 1e-4
+        assert (sample - old_sample).abs().max() < 3e-3

     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
```
```diff
@@ -642,7 +643,7 @@ def test_custom_diffusion_processors(self):
         with torch.no_grad():
             sample2 = model(**inputs_dict).sample

-        assert (sample1 - sample2).abs().max() < 1e-4
+        assert (sample1 - sample2).abs().max() < 3e-3

     def test_custom_diffusion_save_load(self):
         # enable deterministic behavior for gradient checkpointing
```
```diff
@@ -677,7 +678,7 @@ def test_custom_diffusion_save_load(self):
         assert (sample - new_sample).abs().max() < 1e-4

         # custom diffusion and no custom diffusion should be the same
-        assert (sample - old_sample).abs().max() < 1e-4
+        assert (sample - old_sample).abs().max() < 3e-3

     @unittest.skipIf(
         torch_device != "cuda" or not is_xformers_available(),
```
```diff
@@ -957,7 +958,7 @@ def test_compvis_sd_inpaint(self, seed, timestep, expected_slice):
         output_slice = sample[-1, -2:, -2:, :2].flatten().float().cpu()
         expected_output_slice = torch.tensor(expected_slice)

-        assert torch_all_close(output_slice, expected_output_slice, atol=1e-3)
+        assert torch_all_close(output_slice, expected_output_slice, atol=3e-3)

     @parameterized.expand(
         [
```
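Unlike the `rtol` check earlier, this test uses an absolute tolerance, which suits outputs of order 1; a pure relative tolerance would be needlessly strict for elements near zero. A sketch, assuming `torch.allclose` semantics:

```python
import torch

expected = torch.tensor([0.52, -0.31, 0.08])
observed = expected + 2e-3  # uniform absolute drift

# Absolute tolerance: |a - b| <= atol regardless of magnitude.
assert torch.allclose(observed, expected, rtol=0.0, atol=3e-3)

# Relative-only tolerance is too tight near zero: 1e-2 * 0.08 = 8e-4 < 2e-3.
assert not torch.allclose(observed, expected, rtol=1e-2, atol=0.0)
```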
Comment: totally fine for me!