Skip to content

Commit 448246b

Browse files
NucsOceania2018
authored andcommitted
Supported for forced singlethreading
- Separated multithreading related methods to classname.threading.cs partial file - ops: Added enforce_singlethreading(), enforce_multithreading()
1 parent 11224e4 commit 448246b

File tree

9 files changed

+319
-70
lines changed

9 files changed

+319
-70
lines changed

src/TensorFlowNET.Core/Sessions/Session.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -37,8 +37,7 @@ public Session(Graph g, ConfigProto config = null, Status s = null) : base("", g
3737

3838
public Session as_default()
3939
{
40-
tf._defaultSessionFactory.Value = this;
41-
return this;
40+
return ops.set_default_session(this);
4241
}
4342

4443
[MethodImpl(MethodImplOptions.NoOptimization)]

src/TensorFlowNET.Core/ops.cs

-60
Original file line numberDiff line numberDiff line change
@@ -28,10 +28,6 @@ namespace Tensorflow
2828
{
2929
public partial class ops
3030
{
31-
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack());
32-
33-
public static DefaultGraphStack default_graph_stack => _defaultGraphFactory.Value;
34-
3531
public static int tensor_id(Tensor tensor)
3632
{
3733
return tensor.Id;
@@ -78,53 +74,6 @@ public static List<T> get_collection_ref<T>(string key)
7874
return get_default_graph().get_collection_ref<T>(key);
7975
}
8076

81-
/// <summary>
82-
/// Returns the default graph for the current thread.
83-
///
84-
/// The returned graph will be the innermost graph on which a
85-
/// `Graph.as_default()` context has been entered, or a global default
86-
/// graph if none has been explicitly created.
87-
///
88-
/// NOTE: The default graph is a property of the current thread.If you
89-
/// create a new thread, and wish to use the default graph in that
90-
/// thread, you must explicitly add a `with g.as_default():` in that
91-
/// thread's function.
92-
/// </summary>
93-
/// <returns></returns>
94-
public static Graph get_default_graph()
95-
{
96-
//TODO: original source indicates there should be a _default_graph_stack!
97-
//return _default_graph_stack.get_default()
98-
return default_graph_stack.get_controller();
99-
}
100-
101-
public static Graph set_default_graph(Graph graph)
102-
{
103-
//TODO: original source does not have a 'set_default_graph' and indicates there should be a _default_graph_stack!
104-
default_graph_stack.set_controller(graph);
105-
return default_graph_stack.get_controller();
106-
}
107-
108-
/// <summary>
109-
/// Clears the default graph stack and resets the global default graph.
110-
///
111-
/// NOTE: The default graph is a property of the current thread.This
112-
/// function applies only to the current thread.Calling this function while
113-
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
114-
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
115-
/// after calling this function will result in undefined behavior.
116-
/// </summary>
117-
/// <returns></returns>
118-
public static void reset_default_graph()
119-
{
120-
//TODO: original source indicates there should be a _default_graph_stack!
121-
//if (!_default_graph_stack.is_cleared())
122-
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
123-
// "nested graphs. If you need a cleared graph, " +
124-
// "exit the nesting and create a new graph.");
125-
default_graph_stack.reset();
126-
}
127-
12877
public static Graph _get_graph_from_inputs(params Tensor[] op_input_list)
12978
=> _get_graph_from_inputs(op_input_list: op_input_list, graph: null);
13079

@@ -399,15 +348,6 @@ public static NDArray _eval_using_default_session(Tensor tensor, FeedItem[] feed
399348
return session.run(tensor, feed_dict);
400349
}
401350

402-
/// <summary>
403-
/// Returns the default session for the current thread.
404-
/// </summary>
405-
/// <returns>The default `Session` being used in the current thread.</returns>
406-
public static Session get_default_session()
407-
{
408-
return tf.defaultSession;
409-
}
410-
411351
/// <summary>
412352
/// Prepends name scope to a name.
413353
/// </summary>
+152
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,152 @@
1+
using System.Threading;
2+
using Tensorflow.Util;
3+
using static Tensorflow.Binding;
4+
5+
namespace Tensorflow
6+
{
7+
public partial class ops
8+
{
9+
private static readonly ThreadLocal<DefaultGraphStack> _defaultGraphFactory = new ThreadLocal<DefaultGraphStack>(() => new DefaultGraphStack());
10+
private static volatile Session _singleSesson;
11+
private static volatile DefaultGraphStack _singleGraphStack;
12+
private static readonly object _threadingLock = new object();
13+
14+
public static DefaultGraphStack default_graph_stack
15+
{
16+
get
17+
{
18+
if (!isSingleThreaded)
19+
return _defaultGraphFactory.Value;
20+
21+
if (_singleGraphStack == null)
22+
{
23+
lock (_threadingLock)
24+
{
25+
if (_singleGraphStack == null)
26+
_singleGraphStack = new DefaultGraphStack();
27+
}
28+
}
29+
30+
return _singleGraphStack;
31+
}
32+
}
33+
34+
private static bool isSingleThreaded = false;
35+
36+
/// <summary>
37+
/// Does this library ignore different thread accessing.
38+
/// </summary>
39+
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading </remarks>
40+
public static bool IsSingleThreaded
41+
{
42+
get => isSingleThreaded;
43+
set
44+
{
45+
if (value)
46+
enforce_singlethreading();
47+
else
48+
enforce_multithreading();
49+
}
50+
}
51+
52+
/// <summary>
53+
/// Forces the library to ignore different thread accessing.
54+
/// </summary>
55+
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a multithreaded manner</remarks>
56+
public static void enforce_singlethreading()
57+
{
58+
isSingleThreaded = true;
59+
}
60+
61+
/// <summary>
62+
/// Forces the library to provide a separate <see cref="Session"/> and <see cref="Graph"/> to every different thread accessing.
63+
/// </summary>
64+
/// <remarks>https://github.com/SciSharp/TensorFlow.NET/wiki/Multithreading <br></br>Note that this discards any sessions and graphs used in a singlethreaded manner</remarks>
65+
public static void enforce_multithreading()
66+
{
67+
isSingleThreaded = false;
68+
}
69+
70+
/// <summary>
71+
/// Returns the default session for the current thread.
72+
/// </summary>
73+
/// <returns>The default `Session` being used in the current thread.</returns>
74+
public static Session get_default_session()
75+
{
76+
if (!isSingleThreaded)
77+
return tf.defaultSession;
78+
79+
if (_singleSesson == null)
80+
{
81+
lock (_threadingLock)
82+
{
83+
if (_singleSesson == null)
84+
_singleSesson = new Session();
85+
}
86+
}
87+
88+
return _singleSesson;
89+
}
90+
91+
/// <summary>
92+
/// Returns the default session for the current thread.
93+
/// </summary>
94+
/// <returns>The default `Session` being used in the current thread.</returns>
95+
public static Session set_default_session(Session sess)
96+
{
97+
if (!isSingleThreaded)
98+
return tf.defaultSession = sess;
99+
100+
lock (_threadingLock)
101+
{
102+
_singleSesson = sess;
103+
}
104+
105+
return _singleSesson;
106+
}
107+
108+
/// <summary>
109+
/// Returns the default graph for the current thread.
110+
///
111+
/// The returned graph will be the innermost graph on which a
112+
/// `Graph.as_default()` context has been entered, or a global default
113+
/// graph if none has been explicitly created.
114+
///
115+
/// NOTE: The default graph is a property of the current thread.If you
116+
/// create a new thread, and wish to use the default graph in that
117+
/// thread, you must explicitly add a `with g.as_default():` in that
118+
/// thread's function.
119+
/// </summary>
120+
/// <returns></returns>
121+
public static Graph get_default_graph()
122+
{
123+
//return _default_graph_stack.get_default()
124+
return default_graph_stack.get_controller();
125+
}
126+
127+
public static Graph set_default_graph(Graph graph)
128+
{
129+
default_graph_stack.set_controller(graph);
130+
return default_graph_stack.get_controller();
131+
}
132+
133+
/// <summary>
134+
/// Clears the default graph stack and resets the global default graph.
135+
///
136+
/// NOTE: The default graph is a property of the current thread.This
137+
/// function applies only to the current thread.Calling this function while
138+
/// a `tf.Session` or `tf.InteractiveSession` is active will result in undefined
139+
/// behavior. Using any previously created `tf.Operation` or `tf.Tensor` objects
140+
/// after calling this function will result in undefined behavior.
141+
/// </summary>
142+
/// <returns></returns>
143+
public static void reset_default_graph()
144+
{
145+
//if (!_default_graph_stack.is_cleared())
146+
// throw new InvalidOperationException("Do not use tf.reset_default_graph() to clear " +
147+
// "nested graphs. If you need a cleared graph, " +
148+
// "exit the nesting and create a new graph.");
149+
default_graph_stack.reset();
150+
}
151+
}
152+
}

src/TensorFlowNET.Core/tensorflow.cs

+2-4
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ namespace Tensorflow
2121
{
2222
public partial class tensorflow : IObjectLife
2323
{
24-
protected internal readonly ThreadLocal<Session> _defaultSessionFactory;
25-
2624
public TF_DataType @byte = TF_DataType.TF_UINT8;
2725
public TF_DataType @sbyte = TF_DataType.TF_INT8;
2826
public TF_DataType int16 = TF_DataType.TF_INT16;
@@ -40,10 +38,10 @@ public partial class tensorflow : IObjectLife
4038

4139
public tensorflow()
4240
{
43-
_defaultSessionFactory = new ThreadLocal<Session>(() => new Session());
41+
_constructThreadingObjects();
4442
}
4543

46-
public Session defaultSession => _defaultSessionFactory.Value;
44+
4745

4846
public RefVariable Variable<T>(T data,
4947
bool trainable = true,
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using System.Runtime.CompilerServices;
18+
using System.Threading;
19+
20+
namespace Tensorflow
21+
{
22+
public partial class tensorflow : IObjectLife
23+
{
24+
protected ThreadLocal<Session> _defaultSessionFactory;
25+
26+
[MethodImpl(MethodImplOptions.AggressiveInlining)]
27+
public void _constructThreadingObjects()
28+
{
29+
_defaultSessionFactory = new ThreadLocal<Session>(() => new Session());
30+
}
31+
32+
public Session defaultSession
33+
{
34+
get
35+
{
36+
if (!ops.IsSingleThreaded)
37+
return _defaultSessionFactory.Value;
38+
39+
return ops.get_default_session();
40+
}
41+
internal set
42+
{
43+
if (!ops.IsSingleThreaded)
44+
{
45+
_defaultSessionFactory.Value = value;
46+
return;
47+
}
48+
49+
ops.set_default_session(value);
50+
}
51+
}
52+
}
53+
}

0 commit comments

Comments
 (0)