-
Notifications
You must be signed in to change notification settings - Fork 536
cached_session for graph tests #1172
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -388,22 +388,19 @@ public void testBoundaryStop() | |
|
||
} | ||
|
||
[Ignore("TODO")] | ||
[TestMethod] | ||
public void testBoundaryContinue() | ||
{ | ||
//@test_util.run_v1_only("b/120545219") | ||
//def testBoundaryContinue(self): | ||
// # Test that we differentiate both 'x' and 'y' correctly when x is a | ||
// # predecessor of y. | ||
// with self.cached_session(): | ||
// x = constant(1.0) | ||
// y = x * 2.0 | ||
// z = y * 3.0 | ||
// grads = gradients.gradients(z, [x, y]) | ||
// self.assertTrue(all(x is not None for x in grads)) | ||
// self.assertEqual(6.0, grads[0].eval()) | ||
// Test that we differentiate both 'x' and 'y' correctly when x is a | ||
// predecessor of y. | ||
|
||
self.cached_session(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Oceania2018 I'm not sure how exactly "with" should be ported there, since "IEnumerable" isn't "IDisposable" and we don't need to call any "Dispose" for the enumerable itself. I guess session's "Dispose" should be called automatically. |
||
var x = tf.constant(1.0); | ||
var y = x * 2.0; | ||
var z = y * 3.0; | ||
var grads = tf.gradients(z, new[] { x, y }); | ||
self.assertTrue(all(grads.Select(x => x != null))); | ||
self.assertEqual(6.0, grads[0].eval()); | ||
} | ||
|
||
[Ignore("TODO")] | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -6,6 +6,8 @@ | |
using System.Linq; | ||
using Tensorflow; | ||
using static Tensorflow.Binding; | ||
using OneOf.Types; | ||
using System.Collections.Generic; | ||
|
||
namespace TensorFlowNET.UnitTest | ||
{ | ||
|
@@ -139,6 +141,21 @@ public void assertProtoEquals(object toProto, object o) | |
|
||
#region tensor evaluation and test session | ||
|
||
private Session _cached_session = null; | ||
private Graph _cached_graph = null; | ||
private object _cached_config = null; | ||
private bool _cached_force_gpu = false; | ||
|
||
private void _ClearCachedSession() | ||
{ | ||
if (self._cached_session != null) | ||
{ | ||
self._cached_session.Dispose(); | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. @Oceania2018 should we cleanup the graph and the config as well? |
||
self._cached_session = null; | ||
} | ||
} | ||
|
||
|
||
//protected object _eval_helper(Tensor[] tensors) | ||
//{ | ||
// if (tensors == null) | ||
|
@@ -203,10 +220,57 @@ public T evaluate<T>(Tensor tensor) | |
} | ||
} | ||
|
||
|
||
public Session cached_session() | ||
///Returns a TensorFlow Session for use in executing tests. | ||
public IEnumerable<Session> cached_session( | ||
Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) | ||
{ | ||
throw new NotImplementedException(); | ||
// This method behaves differently than self.session(): for performance reasons | ||
// `cached_session` will by default reuse the same session within the same | ||
// test.The session returned by this function will only be closed at the end | ||
// of the test(in the TearDown function). | ||
|
||
// Use the `use_gpu` and `force_gpu` options to control where ops are run.If | ||
// `force_gpu` is True, all ops are pinned to `/ device:GPU:0`. Otherwise, if | ||
// `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as | ||
// possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to | ||
// the CPU. | ||
|
||
// Example: | ||
// python | ||
// class MyOperatorTest(test_util.TensorFlowTestCase) : | ||
// def testMyOperator(self): | ||
// with self.cached_session() as sess: | ||
// valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] | ||
// result = MyOperator(valid_input).eval() | ||
// self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] | ||
// invalid_input = [-1.0, 2.0, 7.0] | ||
// with self.assertRaisesOpError("negative input not supported"): | ||
// MyOperator(invalid_input).eval() | ||
|
||
|
||
// Args: | ||
// graph: Optional graph to use during the returned session. | ||
// config: An optional config_pb2.ConfigProto to use to configure the | ||
// session. | ||
// use_gpu: If True, attempt to run as many ops as possible on GPU. | ||
// force_gpu: If True, pin all ops to `/device:GPU:0`. | ||
|
||
// Yields: | ||
// A Session object that should be used as a context manager to surround | ||
// the graph building and execution code in a test case. | ||
|
||
|
||
// TODO: | ||
// if context.executing_eagerly(): | ||
// return self._eval_helper(tensors) | ||
// else: | ||
{ | ||
var sess = self._get_cached_session( | ||
graph, config, force_gpu, crash_if_inconsistent_args: true); | ||
var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu); | ||
return cached; | ||
|
||
} | ||
} | ||
|
||
//Returns a TensorFlow Session for use in executing tests. | ||
|
@@ -254,6 +318,40 @@ public Session session(Graph graph = null, object config = null, bool use_gpu = | |
return s.as_default(); | ||
} | ||
|
||
private IEnumerable<Session> _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu) | ||
{ | ||
// Set the session and its graph to global default and constrain devices.""" | ||
// if context.executing_eagerly(): | ||
// yield None | ||
// else: | ||
{ | ||
sess.graph.as_default(); | ||
sess.as_default(); | ||
{ | ||
if (force_gpu) | ||
{ | ||
// TODO: | ||
|
||
// Use the name of an actual device if one is detected, or | ||
// '/device:GPU:0' otherwise | ||
/* var gpu_name = gpu_device_name(); | ||
if (!gpu_name) | ||
gpu_name = "/device:GPU:0" | ||
using (sess.graph.device(gpu_name)) { | ||
yield return sess; | ||
}*/ | ||
yield return sess; | ||
} | ||
else if (use_gpu) | ||
yield return sess; | ||
else | ||
using (sess.graph.device("/device:CPU:0")) | ||
yield return sess; | ||
} | ||
|
||
} | ||
} | ||
|
||
// See session() for details. | ||
private Session _create_session(Graph graph, object cfg, bool forceGpu) | ||
{ | ||
|
@@ -298,6 +396,50 @@ private Session _create_session(Graph graph, object cfg, bool forceGpu) | |
return new Session(graph);//, config = prepare_config(config)) | ||
} | ||
|
||
private Session _get_cached_session( | ||
Graph graph = null, | ||
object config = null, | ||
bool force_gpu = false, | ||
bool crash_if_inconsistent_args = true) | ||
{ | ||
// See cached_session() for documentation. | ||
if (self._cached_session == null) | ||
{ | ||
var sess = self._create_session(graph, config, force_gpu); | ||
self._cached_session = sess; | ||
self._cached_graph = graph; | ||
self._cached_config = config; | ||
self._cached_force_gpu = force_gpu; | ||
return sess; | ||
} else { | ||
|
||
if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph)) | ||
throw new ValueError(@"The graph used to get the cached session is | ||
different than the one that was used to create the | ||
session. Maybe create a new session with | ||
self.session()"); | ||
if (crash_if_inconsistent_args && !self._cached_config.Equals(config)) { | ||
throw new ValueError(@"The config used to get the cached session is | ||
different than the one that was used to create the | ||
session. Maybe create a new session with | ||
self.session()"); | ||
} | ||
if (crash_if_inconsistent_args && !self._cached_force_gpu.Equals(force_gpu)) { | ||
throw new ValueError(@"The force_gpu value used to get the cached session is | ||
different than the one that was used to create the | ||
session. Maybe create a new session with | ||
self.session()"); | ||
} | ||
return _cached_session; | ||
} | ||
} | ||
|
||
[TestCleanup] | ||
public void Cleanup() | ||
{ | ||
_ClearCachedSession(); | ||
} | ||
|
||
#endregion | ||
|
||
public void AssetSequenceEqual<T>(T[] a, T[] b) | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@Oceania2018 I'm not sure how exactly IEnumerable should work there since Python generator produces only one instance. Anyway this test wasn't working and it is still not implemented yet.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The test didn't come through.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sorry, it was flaky. Now it should work.
Uh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Since the different behavior of
yield
in Python and C#, If cached_session only need to have one session, I think it is better not to useyield
? And then you can useusing
to replacewith
in GradientTest.csUh oh!
There was an error while loading. Please reload this page.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Could, please, explain me why
yield
used in Python in this case?Anyway fixed.