Skip to content

Commit 2ea8f2e

Browse files
committed
fix is_build_function in Layer.
1 parent f7412a1 commit 2ea8f2e

File tree

4 files changed

+24
-17
lines changed

4 files changed

+24
-17
lines changed

src/TensorFlowNET.Core/Contexts/Context.cs

+8-10
Original file line numberDiff line numberDiff line change
@@ -42,7 +42,7 @@ public Context(ContextOptions opts, Status status)
4242
{
4343
Handle = c_api.TFE_NewContext(opts.Handle, status.Handle);
4444
status.Check(true);
45-
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE);
45+
context_switches = new ContextSwitchStack(defaultExecutionMode == EAGER_MODE, false);
4646
initialized = true;
4747
}
4848

@@ -70,21 +70,19 @@ public void end_step()
7070
public bool executing_eagerly()
7171
=> context_switches.Current().EagerMode;
7272

73+
public bool is_build_function()
74+
=> context_switches.Current().IsBuildingFunction;
75+
7376
public string shared_name(string name = null)
7477
=> !string.IsNullOrEmpty(name) || !executing_eagerly() ?
7578
name :
7679
"cd2c89b7-88b7-44c8-ad83-06c2a9158347";
7780

78-
public void graph_mode()
79-
=> mode(false);
80-
81-
public void eager_mode()
82-
=> mode(true);
81+
public void graph_mode(bool isFunc = false)
82+
=> context_switches.Push(false, isFunc);
8383

84-
void mode(bool isEager)
85-
{
86-
context_switches.Push(isEager);
87-
}
84+
public void eager_mode(bool isFunc = false)
85+
=> context_switches.Push(true, isFunc);
8886

8987
public void restore_mode()
9088
{

src/TensorFlowNET.Core/Contexts/ContextSwitchStack.cs

+5-4
Original file line numberDiff line numberDiff line change
@@ -25,17 +25,18 @@ public class ContextSwitchStack
2525
{
2626
Stack<ContextSwitch> stack;
2727

28-
public ContextSwitchStack(bool isEager)
28+
public ContextSwitchStack(bool isEager, bool isFunc)
2929
{
3030
stack = new Stack<ContextSwitch>();
31-
Push(isEager);
31+
Push(isEager, isFunc);
3232
}
3333

34-
public void Push(bool isEager)
34+
public void Push(bool isEager, bool isFunc)
3535
{
3636
stack.Push(new ContextSwitch
3737
{
38-
EagerMode = isEager
38+
EagerMode = isEager,
39+
IsBuildingFunction = isFunc
3940
});
4041
}
4142

src/TensorFlowNET.Keras/Engine/Layer.FunctionalConstructionCall.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ Tensors FunctionalConstructionCall(Tensors inputs)
2525
// using var graph = tf.keras.backend.get_graph().as_default();
2626

2727
if (!inputs.IsEagerTensor)
28-
tf.Context.graph_mode();
28+
tf.Context.graph_mode(isFunc: true);
2929

3030
tf_with(ops.name_scope(_name_scope()), scope =>
3131
{

src/TensorFlowNET.Keras/Engine/Layer.cs

+10-2
Original file line numberDiff line numberDiff line change
@@ -176,9 +176,17 @@ protected void MaybeBuild(Tensors inputs)
176176

177177
tf.init_scope();
178178

179-
tf.Context.eager_mode();
179+
bool need_restore_mode = false;
180+
if (inputs.IsEagerTensor || tf.Context.is_build_function())
181+
{
182+
need_restore_mode = true;
183+
tf.Context.eager_mode();
184+
}
185+
180186
build(inputs);
181-
tf.Context.restore_mode();
187+
188+
if (need_restore_mode)
189+
tf.Context.restore_mode();
182190

183191
built = true;
184192
}

0 commit comments

Comments
 (0)