Skip to content

Commit 4cbb062

Browse files
using and no IEnumerable
1 parent ae50fa9 commit 4cbb062

File tree

2 files changed

+19
-19
lines changed

2 files changed

+19
-19
lines changed

test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs

+9-7
Original file line numberDiff line numberDiff line change
@@ -394,13 +394,15 @@ public void testBoundaryContinue()
394394
// Test that we differentiate both 'x' and 'y' correctly when x is a
395395
// predecessor of y.
396396

397-
var sess = self.cached_session().Single();
398-
var x = tf.constant(1.0);
399-
var y = x * 2.0;
400-
var z = y * 3.0;
401-
var grads = tf.gradients(z, new[] { x, y });
402-
self.assertTrue(all(grads.Select(x => x != null)));
403-
self.assertEqual(6.0, grads[0].eval());
397+
using (self.cached_session())
398+
{
399+
var x = tf.constant(1.0);
400+
var y = x * 2.0;
401+
var z = y * 3.0;
402+
var grads = tf.gradients(z, new[] { x, y });
403+
self.assertTrue(all(grads.Select(x => x != null)));
404+
self.assertEqual(6.0, grads[0].eval());
405+
}
404406
}
405407

406408
[Ignore("TODO")]

test/TensorFlowNET.Graph.UnitTest/PythonTest.cs

+10-12
Original file line numberDiff line numberDiff line change
@@ -221,7 +221,7 @@ public T evaluate<T>(Tensor tensor)
221221
}
222222

223223
///Returns a TensorFlow Session for use in executing tests.
224-
public IEnumerable<Session> cached_session(
224+
public Session cached_session(
225225
Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
226226
{
227227
// This method behaves differently than self.session(): for performance reasons
@@ -267,9 +267,8 @@ public IEnumerable<Session> cached_session(
267267
{
268268
var sess = self._get_cached_session(
269269
graph, config, force_gpu, crash_if_inconsistent_args: true);
270-
var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu);
271-
return cached;
272-
270+
using var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu);
271+
return cached;
273272
}
274273
}
275274

@@ -318,13 +317,12 @@ public Session session(Graph graph = null, object config = null, bool use_gpu =
318317
return s.as_default();
319318
}
320319

321-
private IEnumerable<Session> _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu)
320+
private Session _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu)
322321
{
323322
// Set the session and its graph to global default and constrain devices."""
324-
// if context.executing_eagerly():
325-
// yield None
326-
// else:
327-
{
323+
if (tf.executing_eagerly())
324+
return null;
325+
else {
328326
sess.graph.as_default();
329327
sess.as_default();
330328
{
@@ -340,13 +338,13 @@ private IEnumerable<Session> _constrain_devices_and_set_default(Session sess, bo
340338
using (sess.graph.device(gpu_name)) {
341339
yield return sess;
342340
}*/
343-
yield return sess;
341+
return sess;
344342
}
345343
else if (use_gpu)
346-
yield return sess;
344+
return sess;
347345
else
348346
using (sess.graph.device("/device:CPU:0"))
349-
yield return sess;
347+
return sess;
350348
}
351349

352350
}

0 commit comments

Comments
 (0)