|
6 | 6 | using System.Linq;
|
7 | 7 | using Tensorflow;
|
8 | 8 | using static Tensorflow.Binding;
|
| 9 | +using OneOf.Types; |
| 10 | +using System.Collections.Generic; |
9 | 11 |
|
10 | 12 | namespace TensorFlowNET.UnitTest
|
11 | 13 | {
|
@@ -139,6 +141,21 @@ public void assertProtoEquals(object toProto, object o)
|
139 | 141 |
|
140 | 142 | #region tensor evaluation and test session
|
141 | 143 |
|
| 144 | + private Session _cached_session = null; |
| 145 | + private Graph _cached_graph = null; |
| 146 | + private object _cached_config = null; |
| 147 | + private bool _cached_force_gpu = false; |
| 148 | + |
| 149 | + private void _ClearCachedSession() |
| 150 | + { |
| 151 | + if (self._cached_session != null) |
| 152 | + { |
| 153 | + self._cached_session.Dispose(); |
| 154 | + self._cached_session = null; |
| 155 | + } |
| 156 | + } |
| 157 | + |
| 158 | + |
142 | 159 | //protected object _eval_helper(Tensor[] tensors)
|
143 | 160 | //{
|
144 | 161 | // if (tensors == null)
|
@@ -203,10 +220,57 @@ public T evaluate<T>(Tensor tensor)
|
203 | 220 | }
|
204 | 221 | }
|
205 | 222 |
|
206 |
| - |
207 |
| - public Session cached_session() |
| 223 | + ///Returns a TensorFlow Session for use in executing tests. |
| 224 | + public IEnumerable<Session> cached_session( |
| 225 | + Graph graph = null, object config = null, bool use_gpu = false, bool force_gpu = false) |
208 | 226 | {
|
209 |
| - throw new NotImplementedException(); |
| 227 | + // This method behaves differently than self.session(): for performance reasons |
| 228 | + // `cached_session` will by default reuse the same session within the same |
| 229 | + // test.The session returned by this function will only be closed at the end |
| 230 | + // of the test(in the TearDown function). |
| 231 | + |
| 232 | + // Use the `use_gpu` and `force_gpu` options to control where ops are run.If |
| 233 | + // `force_gpu` is True, all ops are pinned to `/ device:GPU:0`. Otherwise, if |
| 234 | + // `use_gpu` is True, TensorFlow tries to run as many ops on the GPU as |
| 235 | + // possible.If both `force_gpu and `use_gpu` are False, all ops are pinned to |
| 236 | + // the CPU. |
| 237 | + |
| 238 | + // Example: |
| 239 | + // python |
| 240 | + // class MyOperatorTest(test_util.TensorFlowTestCase) : |
| 241 | + // def testMyOperator(self): |
| 242 | + // with self.cached_session() as sess: |
| 243 | + // valid_input = [1.0, 2.0, 3.0, 4.0, 5.0] |
| 244 | + // result = MyOperator(valid_input).eval() |
| 245 | + // self.assertEqual(result, [1.0, 2.0, 3.0, 5.0, 8.0] |
| 246 | + // invalid_input = [-1.0, 2.0, 7.0] |
| 247 | + // with self.assertRaisesOpError("negative input not supported"): |
| 248 | + // MyOperator(invalid_input).eval() |
| 249 | + |
| 250 | + |
| 251 | + // Args: |
| 252 | + // graph: Optional graph to use during the returned session. |
| 253 | + // config: An optional config_pb2.ConfigProto to use to configure the |
| 254 | + // session. |
| 255 | + // use_gpu: If True, attempt to run as many ops as possible on GPU. |
| 256 | + // force_gpu: If True, pin all ops to `/device:GPU:0`. |
| 257 | + |
| 258 | + // Yields: |
| 259 | + // A Session object that should be used as a context manager to surround |
| 260 | + // the graph building and execution code in a test case. |
| 261 | + |
| 262 | + |
| 263 | + // TODO: |
| 264 | + // if context.executing_eagerly(): |
| 265 | + // return self._eval_helper(tensors) |
| 266 | + // else: |
| 267 | + { |
| 268 | + var sess = self._get_cached_session( |
| 269 | + graph, config, force_gpu, crash_if_inconsistent_args: true); |
| 270 | + var cached = self._constrain_devices_and_set_default(sess, use_gpu, force_gpu); |
| 271 | + return cached; |
| 272 | + |
| 273 | + } |
210 | 274 | }
|
211 | 275 |
|
212 | 276 | //Returns a TensorFlow Session for use in executing tests.
|
@@ -254,6 +318,40 @@ public Session session(Graph graph = null, object config = null, bool use_gpu =
|
254 | 318 | return s.as_default();
|
255 | 319 | }
|
256 | 320 |
|
| 321 | + private IEnumerable<Session> _constrain_devices_and_set_default(Session sess, bool use_gpu, bool force_gpu) |
| 322 | + { |
| 323 | + // Set the session and its graph to global default and constrain devices.""" |
| 324 | + // if context.executing_eagerly(): |
| 325 | + // yield None |
| 326 | + // else: |
| 327 | + { |
| 328 | + sess.graph.as_default(); |
| 329 | + sess.as_default(); |
| 330 | + { |
| 331 | + if (force_gpu) |
| 332 | + { |
| 333 | + // TODO: |
| 334 | + |
| 335 | + // Use the name of an actual device if one is detected, or |
| 336 | + // '/device:GPU:0' otherwise |
| 337 | + /* var gpu_name = gpu_device_name(); |
| 338 | + if (!gpu_name) |
| 339 | + gpu_name = "/device:GPU:0" |
| 340 | + using (sess.graph.device(gpu_name)) { |
| 341 | + yield return sess; |
| 342 | + }*/ |
| 343 | + yield return sess; |
| 344 | + } |
| 345 | + else if (use_gpu) |
| 346 | + yield return sess; |
| 347 | + else |
| 348 | + using (sess.graph.device("/device:CPU:0")) |
| 349 | + yield return sess; |
| 350 | + } |
| 351 | + |
| 352 | + } |
| 353 | + } |
| 354 | + |
257 | 355 | // See session() for details.
|
258 | 356 | private Session _create_session(Graph graph, object cfg, bool forceGpu)
|
259 | 357 | {
|
@@ -298,6 +396,50 @@ private Session _create_session(Graph graph, object cfg, bool forceGpu)
|
298 | 396 | return new Session(graph);//, config = prepare_config(config))
|
299 | 397 | }
|
300 | 398 |
|
| 399 | + private Session _get_cached_session( |
| 400 | + Graph graph = null, |
| 401 | + object config = null, |
| 402 | + bool force_gpu = false, |
| 403 | + bool crash_if_inconsistent_args = true) |
| 404 | + { |
| 405 | + // See cached_session() for documentation. |
| 406 | + if (self._cached_session == null) |
| 407 | + { |
| 408 | + var sess = self._create_session(graph, config, force_gpu); |
| 409 | + self._cached_session = sess; |
| 410 | + self._cached_graph = graph; |
| 411 | + self._cached_config = config; |
| 412 | + self._cached_force_gpu = force_gpu; |
| 413 | + return sess; |
| 414 | + } else { |
| 415 | + |
| 416 | + if (crash_if_inconsistent_args && !self._cached_graph.Equals(graph)) |
| 417 | + throw new ValueError(@"The graph used to get the cached session is |
| 418 | + different than the one that was used to create the |
| 419 | + session. Maybe create a new session with |
| 420 | + self.session()"); |
| 421 | + if (crash_if_inconsistent_args && !self._cached_config.Equals(config)) { |
| 422 | + throw new ValueError(@"The config used to get the cached session is |
| 423 | + different than the one that was used to create the |
| 424 | + session. Maybe create a new session with |
| 425 | + self.session()"); |
| 426 | + } |
| 427 | + if (crash_if_inconsistent_args && !self._cached_force_gpu.Equals(force_gpu)) { |
| 428 | + throw new ValueError(@"The force_gpu value used to get the cached session is |
| 429 | + different than the one that was used to create the |
| 430 | + session. Maybe create a new session with |
| 431 | + self.session()"); |
| 432 | + } |
| 433 | + return _cached_session; |
| 434 | + } |
| 435 | + } |
| 436 | + |
| 437 | + [TestCleanup] |
| 438 | + public void Cleanup() |
| 439 | + { |
| 440 | + _ClearCachedSession(); |
| 441 | + } |
| 442 | + |
301 | 443 | #endregion
|
302 | 444 |
|
303 | 445 | public void AssetSequenceEqual<T>(T[] a, T[] b)
|
|
0 commit comments