Skip to content

Commit 0a08386

Browse files
committed
Fix batch_size for Keras Input.
1 parent def0664 commit 0a08386

File tree

3 files changed

+4
-0
lines changed

3 files changed

+4
-0
lines changed

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

+1
Original file line numberDiff line numberDiff line change
@@ -108,6 +108,7 @@ public ILayer EinsumDense(string equation,
108108
public ILayer GlobalMaxPooling2D(string data_format = "channels_last");
109109

110110
public Tensors Input(Shape shape,
111+
int batch_size = -1,
111112
string name = null,
112113
bool sparse = false,
113114
bool ragged = false);

src/TensorFlowNET.Keras/Layers/LayersApi.cs

+2
Original file line numberDiff line numberDiff line change
@@ -469,13 +469,15 @@ public ILayer Flatten(string data_format = null)
469469
/// </param>
470470
/// <returns>A tensor.</returns>
471471
public Tensors Input(Shape shape,
472+
int batch_size = -1,
472473
string name = null,
473474
bool sparse = false,
474475
bool ragged = false)
475476
{
476477
var input_layer = new InputLayer(new InputLayerArgs
477478
{
478479
InputShape = shape,
480+
BatchSize= batch_size,
479481
Name = name,
480482
Sparse = sparse,
481483
Ragged = ragged

src/python/xor_keras.py

+1
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
66
print(tf.__version__)
7+
# https://playground.tensorflow.org/
78
# tf.compat.v1.enable_eager_execution()
89
# tf.debugging.set_log_device_placement(True);
910
tf.config.run_functions_eagerly(True)

0 commit comments

Comments
 (0)