Skip to content

Commit 5656811

Browse files
committed
nit
1 parent 9b608aa commit 5656811

File tree

1 file changed

+6
-10
lines changed

1 file changed

+6
-10
lines changed

test/test_models.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -133,12 +133,10 @@ def get_export_import_copy(m):
133133
return imported
134134

135135
m_import = get_export_import_copy(m)
136-
with torch.no_grad():
137-
with freeze_rng_state():
138-
results = m(*args)
139-
with torch.no_grad():
140-
with freeze_rng_state():
141-
results_from_imported = m_import(*args)
136+
with torch.no_grad(), freeze_rng_state():
137+
results = m(*args)
138+
with torch.no_grad(), freeze_rng_state():
139+
results_from_imported = m_import(*args)
142140
tol = 3e-4
143141
torch.testing.assert_close(results, results_from_imported, atol=tol, rtol=tol)
144142

@@ -158,12 +156,10 @@ def get_export_import_copy(m):
158156

159157
sm = torch.jit.script(nn_module)
160158

161-
with torch.no_grad():
162-
with freeze_rng_state():
159+
with torch.no_grad(), freeze_rng_state():
163160
eager_out = nn_module(*args)
164161

165-
with torch.no_grad():
166-
with freeze_rng_state():
162+
with torch.no_grad(), freeze_rng_state():
167163
script_out = sm(*args)
168164
if unwrapper:
169165
script_out = unwrapper(script_out)

0 commit comments

Comments
 (0)