Skip to content

Commit c33a3a7

Browse files
cached_session for graph tests
1 parent 179c3f0 commit c33a3a7

File tree

3 files changed

+156
-16
lines changed

3 files changed

+156
-16
lines changed

test/TensorFlowNET.Graph.UnitTest/ControlFlowTest/WhileContextTestCase.cs

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using System;
3+
using System.Linq;
34
using Tensorflow;
45
using static Tensorflow.Binding;
56

@@ -29,7 +30,7 @@ private void _testWhileContextHelper(int maximum_iterations)
2930
var b = new Func<Tensor, Tensor>(x => math_ops.add(x, 1, name: "c"));
3031
//control_flow_ops.while_loop(
3132
// c, b, i , maximum_iterations: tf.constant(maximum_iterations));
32-
foreach (Operation op in sess.graph.get_operations())
33+
foreach (Operation op in sess.Single().graph.get_operations())
3334
{
3435
var control_flow_context = op._get_control_flow_context();
3536
/*if (control_flow_context != null)

test/TensorFlowNET.Graph.UnitTest/GradientTest/GradientTest.cs

+9-12
Original file line numberDiff line numberDiff line change
@@ -388,22 +388,19 @@ public void testBoundaryStop()
388388

389389
}
390390

391-
[Ignore("TODO")]
392391
[TestMethod]
393392
public void testBoundaryContinue()
394393
{
395-
//@test_util.run_v1_only("b/120545219")
396-
//def testBoundaryContinue(self):
397-
// # Test that we differentiate both 'x' and 'y' correctly when x is a
398-
// # predecessor of y.
399-
// with self.cached_session():
400-
// x = constant(1.0)
401-
// y = x * 2.0
402-
// z = y * 3.0
403-
// grads = gradients.gradients(z, [x, y])
404-
// self.assertTrue(all(x is not None for x in grads))
405-
// self.assertEqual(6.0, grads[0].eval())
394+
// Test that we differentiate both 'x' and 'y' correctly when x is a
395+
// predecessor of y.
406396

397+
self.cached_session();
398+
var x = tf.constant(1.0);
399+
var y = x * 2.0;
400+
var z = y * 3.0;
401+
var grads = tf.gradients(z, new[] { x, y });
402+
self.assertTrue(all(grads.Select(x => x != null)));
403+
self.assertEqual(6.0, grads[0].eval());
407404
}
408405

409406
[Ignore("TODO")]

test/TensorFlowNET.Graph.UnitTest/PythonTest.cs

+145-3
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
using System.Linq;
77
using Tensorflow;
88
using static Tensorflow.Binding;
9+
using OneOf.Types;
10+
using System.Collections.Generic;
911

1012
namespace TensorFlowNET.UnitTest
1113
{
@@ -139,6 +141,21 @@ public void assertProtoEquals(object toProto, object o)
139141

140142
#region tensor evaluation and test session
141143

144+
private Session _cached_session = null;
145+
private Graph _cached_graph = null;
146+
private object _cached_config = null;
147+
private bool _cached_force_gpu = false;
148+
149+
private void _ClearCachedSession()
150+
{
151+
if (self._cached_session != null)
152+
{
153+
self._cached_session.Dispose();
154+
self._cached_session = null;
155+
}
156+
}
157+
158+
142159
//protected object _eval_helper(Tensor[] tensors)
143160
//{
144161
// if (tensors == null)
@@ -203,10 +220,57 @@ public T evaluate<T>(Tensor tensor)
203220
}
204221
}
205222

206-
207-
public Session cached_session()
223+
///Returns a TensorFlow Session for use in executing tests.
224+
public IEnumerable<Session> cached_session(
225+
Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false)
208226
{
209-
throw new NotImplementedException();
227+
// This method behaves differently than self.session(): for performance reasons
228+
// `cached_session` will by default reuse the same session within the same
229+
// test.The session returned by this function will only be closed at the end
230+
// of the test(in the TearDown function).
231+
232+
// Use the `use_gpu` and `force_gpu` options to control where ops are run.If
233+
// `force_gpu` is True, all ops are pinned to `/ device:GPU:0`. Otherwise, if
234+
// `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as
235+
// possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to
236+
// the CPU.
237+
238+
// Example:
239+
// python
240+
// class MyOperatorTest(test_util.TensorFlowTestCase) :
241+
// def testMyOperator(self):
242+
// with self.cached_session() as sess:
243+
// valid_input = [1.0, 2.0, 3.0, 4.0, 5.0]
244+
// result = MyOperator(valid_input).eval()
245+
// self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0]
246+
// invalid_input = [-1.0, 2.0, 7.0]
247+
// with self.assertRaisesOpError("negative input not supported"):
248+
// MyOperator(invalid_input).eval()
249+
250+
251+
// Args:
252+
// graph: Optional graph to use during the returned session.
253+
// config: An optional config_pb2.ConfigProto to use to configure the
254+
// session.
255+
// use_gpu: If True, attempt to run as many ops as possible on GPU.
256+
// force_gpu: If True, pin all ops to `/device:GPU:0`.
257+
258+
// Yields:
259+
// A Session object that should be used as a context manager to surround
260+
// the graph building and execution code in a test case.
261+
262+
263+
// TODO:
264+
// if context.executing_eagerly():
265+
// return self._eval_helper(tensors)
266+
// else:
267+
{
268+
var sess = self._get_cached_session(
269+
graph, config, force_gpu, crash_if_inconsistent_args: true);
270+
var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu);
271+
return cached;
272+
273+
}
210274
}
211275

212276
//Returns a TensorFlow Session for use in executing tests.
@@ -254,6 +318,40 @@ public Session session(Graph graph = null, object config = null, bool use_gpu =
254318
return s.as_default();
255319
}
256320

321+
private IEnumerable<Session> _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu)
322+
{
323+
// Set the session and its graph to global default and constrain devices."""
324+
// if context.executing_eagerly():
325+
// yield None
326+
// else:
327+
{
328+
sess.graph.as_default();
329+
sess.as_default();
330+
{
331+
if (force_gpu)
332+
{
333+
// TODO:
334+
335+
// Use the name of an actual device if one is detected, or
336+
// '/device:GPU:0' otherwise
337+
/* var gpu_name = gpu_device_name();
338+
if (!gpu_name)
339+
gpu_name = "/device:GPU:0"
340+
using (sess.graph.device(gpu_name)) {
341+
yield return sess;
342+
}*/
343+
yield return sess;
344+
}
345+
else if (use_gpu)
346+
yield return sess;
347+
else
348+
using (sess.graph.device("/device:CPU:0"))
349+
yield return sess;
350+
}
351+
352+
}
353+
}
354+
257355
// See session() for details.
258356
private Session _create_session(Graph graph, object cfg, bool forceGpu)
259357
{
@@ -298,6 +396,50 @@ private Session _create_session(Graph graph, object cfg, bool forceGpu)
298396
return new Session(graph);//, config = prepare_config(config))
299397
}
300398

399+
private Session _get_cached_session(
400+
Graph graph = null,
401+
object config = null,
402+
bool force_gpu = false,
403+
bool crash_if_inconsistent_args = true)
404+
{
405+
// See cached_session() for documentation.
406+
if (self._cached_session == null)
407+
{
408+
var sess = self._create_session(graph, config, force_gpu);
409+
self._cached_session = sess;
410+
self._cached_graph = graph;
411+
self._cached_config = config;
412+
self._cached_force_gpu = force_gpu;
413+
return sess;
414+
} else {
415+
416+
if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph))
417+
throw new ValueError(@"The graph used to get the cached session is
418+
different than the one that was used to create the
419+
session. Maybe create a new session with
420+
self.session()");
421+
if (crash_if_inconsistent_args && !self._cached_config.Equals(config)) {
422+
throw new ValueError(@"The config used to get the cached session is
423+
different than the one that was used to create the
424+
session. Maybe create a new session with
425+
self.session()");
426+
}
427+
if (crash_if_inconsistent_args && !self._cached_force_gpu.Equals(force_gpu)) {
428+
throw new ValueError(@"The force_gpu value used to get the cached session is
429+
different than the one that was used to create the
430+
session. Maybe create a new session with
431+
self.session()");
432+
}
433+
return _cached_session;
434+
}
435+
}
436+
437+
[TestCleanup]
438+
public void Cleanup()
439+
{
440+
_ClearCachedSession();
441+
}
442+
301443
#endregion
302444

303445
public void AssetSequenceEqual<T>(T[] a, T[] b)

0 commit comments

Comments
 (0)