Skip to content

Commit 49876bb

Browse files
committed
Change model.build parameter to input_shape.
1 parent 2adfcd2 commit 49876bb

21 files changed

+163
-139
lines changed

src/TensorFlowNET.Keras/Engine/Layer.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -191,15 +191,15 @@ protected void MaybeBuild(Tensors inputs)
191191
tf.Context.eager_mode(isFunc: tf.Context.is_build_function());
192192
}
193193

194-
build(inputs);
194+
build(inputs.shape);
195195

196196
if (need_restore_mode)
197197
tf.Context.restore_mode();
198198

199199
built = true;
200200
}
201201

202-
protected virtual void build(Tensors inputs)
202+
public virtual void build(Shape input_shape)
203203
{
204204
built = true;
205205
}

src/TensorFlowNET.Keras/Layers/Activation/ELU.cs

+33-25
Original file line numberDiff line numberDiff line change
@@ -6,30 +6,38 @@
66
using static Tensorflow.Binding;
77

88
namespace Tensorflow.Keras.Layers {
9-
/// <summary>
10-
/// ELU Layer:
11-
/// x = 0 when x > 0, x = alpha( e^x-1 ) elsewhere
12-
/// </summary>
13-
public class ELU : Layer {
14-
ELUArgs args;
15-
float alpha => args.Alpha;
16-
public ELU ( ELUArgs args ) : base(args) {
17-
this.args = args;
18-
}
19-
protected override void build ( Tensors inputs ) {
20-
if ( alpha < 0f ) {
21-
throw new ValueError("Alpha must be a number greater than 0.");
22-
}
23-
built = true;
24-
}
25-
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
26-
Tensor output = inputs;
27-
output = tf.where(output > 0f, output,
28-
tf.multiply(alpha, tf.sub(tf.exp(output), 1f)));
29-
return output;
30-
}
31-
public override Shape ComputeOutputShape ( Shape input_shape ) {
32-
return input_shape;
9+
/// <summary>
10+
/// ELU Layer:
11+
/// x = 0 when x > 0, x = alpha( e^x-1 ) elsewhere
12+
/// </summary>
13+
public class ELU : Layer
14+
{
15+
ELUArgs args;
16+
float alpha => args.Alpha;
17+
public ELU(ELUArgs args) : base(args)
18+
{
19+
this.args = args;
20+
}
21+
22+
public override void build(Shape input_shape)
23+
{
24+
if (alpha < 0f)
25+
{
26+
throw new ValueError("Alpha must be a number greater than 0.");
3327
}
34-
}
28+
built = true;
29+
}
30+
31+
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
32+
{
33+
Tensor output = inputs;
34+
output = tf.where(output > 0f, output,
35+
tf.multiply(alpha, tf.sub(tf.exp(output), 1f)));
36+
return output;
37+
}
38+
public override Shape ComputeOutputShape(Shape input_shape)
39+
{
40+
return input_shape;
41+
}
42+
}
3543
}

src/TensorFlowNET.Keras/Layers/Activation/Exponential.cs

+20-15
Original file line numberDiff line numberDiff line change
@@ -6,19 +6,24 @@
66
using static Tensorflow.Binding;
77

88
namespace Tensorflow.Keras.Layers {
9-
public class Exponential : Layer {
10-
public Exponential ( LayerArgs args ) : base(args) {
11-
// Exponential has no args
12-
}
13-
protected override void build ( Tensors inputs ) {
14-
built = true;
15-
}
16-
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
17-
Tensor output = inputs;
18-
return tf.exp(output);
19-
}
20-
public override Shape ComputeOutputShape ( Shape input_shape ) {
21-
return input_shape;
22-
}
23-
}
9+
public class Exponential : Layer
10+
{
11+
public Exponential(LayerArgs args) : base(args)
12+
{
13+
// Exponential has no args
14+
}
15+
public override void build(Shape input_shape)
16+
{
17+
built = true;
18+
}
19+
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
20+
{
21+
Tensor output = inputs;
22+
return tf.exp(output);
23+
}
24+
public override Shape ComputeOutputShape(Shape input_shape)
25+
{
26+
return input_shape;
27+
}
28+
}
2429
}

src/TensorFlowNET.Keras/Layers/Activation/SELU.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ public class SELU : Layer {
1515
public SELU ( LayerArgs args ) : base(args) {
1616
// SELU has no arguments
1717
}
18-
protected override void build ( Tensors inputs ) {
18+
public override void build(Shape input_shape) {
1919
if ( alpha < 0f ) {
2020
throw new ValueError("Alpha must be a number greater than 0.");
2121
}

src/TensorFlowNET.Keras/Layers/Attention/Attention.cs

+4-3
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,10 @@ public Attention(AttentionArgs args) : base(args)
9090
}.Contains(this.score_mode))
9191
throw new ValueError("Received: score_mode={score_mode}. Acceptable values are: [\"dot\", \"concat\"]");
9292
}
93-
93+
9494
// Creates variable when `use_scale` is True or `score_mode` is `concat`.
95-
protected override void build(Tensors inputs) {
95+
public override void build(Shape input_shape)
96+
{
9697
if (this.use_scale)
9798
this.scale = this.add_weight(name: "scale",
9899
shape: 1,
@@ -110,7 +111,7 @@ protected override void build(Tensors inputs) {
110111
trainable: true);
111112
else
112113
this.concat_score_weight = null;
113-
base.build(inputs);
114+
base.build(input_shape);
114115
}
115116

116117
/// <summary>

src/TensorFlowNET.Keras/Layers/Convolution/Conv2DTranspose.cs

+3-6
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,8 @@ public Conv2DTranspose(Conv2DArgs args) : base(args)
2929

3030
}
3131

32-
protected override void build(Tensors inputs)
32+
public override void build(Shape input_shape)
3333
{
34-
var input_shape = inputs.shape;
3534
if (len(input_shape) != 4)
3635
throw new ValueError($"Inputs should have rank 4. Received input shape: {input_shape}");
3736

@@ -43,14 +42,12 @@ protected override void build(Tensors inputs)
4342
shape: kernel_shape,
4443
initializer: kernel_initializer,
4544
regularizer: kernel_regularizer,
46-
trainable: true,
47-
dtype: inputs.dtype);
45+
trainable: true);
4846
if (use_bias)
4947
bias = add_weight(name: "bias",
5048
shape: filters,
5149
initializer: bias_initializer,
52-
trainable: true,
53-
dtype: inputs.dtype);
50+
trainable: true);
5451
built = true;
5552
}
5653

src/TensorFlowNET.Keras/Layers/Convolution/Convolutional.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -57,9 +57,8 @@ public Convolutional(ConvolutionalArgs args) : base(args)
5757
_tf_data_format = conv_utils.convert_data_format(data_format, rank + 2);
5858
}
5959

60-
protected override void build(Tensors inputs)
60+
public override void build(Shape input_shape)
6161
{
62-
Shape input_shape = inputs.shape;
6362
int channel_axis = data_format == "channels_first" ? 1 : -1;
6463
var input_channel = channel_axis < 0 ?
6564
input_shape.dims[input_shape.ndim + channel_axis] :

src/TensorFlowNET.Keras/Layers/Core/Dense.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -41,9 +41,8 @@ public Dense(DenseArgs args) :
4141
this.inputSpec = new InputSpec(min_ndim: 2);
4242
}
4343

44-
protected override void build(Tensors inputs)
44+
public override void build(Shape input_shape)
4545
{
46-
Shape input_shape = inputs.shape;
4746
var last_dim = input_shape.dims.Last();
4847
var axes = new Dictionary<int, int>();
4948
axes[-1] = (int)last_dim;

src/TensorFlowNET.Keras/Layers/Core/EinsumDense.cs

+2-3
Original file line numberDiff line numberDiff line change
@@ -119,9 +119,8 @@ public EinsumDense(EinsumDenseArgs args) : base(args)
119119
this.bias_constraint = args.BiasConstraint;
120120
}
121121

122-
protected override void build(Tensors inputs)
122+
public override void build(Shape input_shape)
123123
{
124-
var input_shape = inputs.shape;
125124
var shape_data = _analyze_einsum_string(this.equation, this.bias_axes, input_shape, this.partial_output_shape);
126125
var kernel_shape = shape_data.Item1;
127126
var bias_shape = shape_data.Item2;
@@ -141,7 +140,7 @@ protected override void build(Tensors inputs)
141140
trainable: true);
142141
else
143142
this.bias = null;
144-
base.build(inputs);
143+
base.build(input_shape);
145144
}
146145

147146
public override Shape ComputeOutputShape(Shape input_shape)

src/TensorFlowNET.Keras/Layers/Core/Embedding.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ public Embedding(EmbeddingArgs args)
5454
SupportsMasking = mask_zero;
5555
}
5656

57-
protected override void build(Tensors inputs)
57+
public override void build(Shape input_shape)
5858
{
5959
tf.Context.eager_mode();
6060
embeddings = add_weight(shape: (input_dim, output_dim),

src/TensorFlowNET.Keras/Layers/Cropping/Cropping1D.cs

+51-39
Original file line numberDiff line numberDiff line change
@@ -2,49 +2,61 @@
22
using Tensorflow.Keras.Engine;
33

44
namespace Tensorflow.Keras.Layers {
5-
public class Cropping1D : Layer {
6-
CroppingArgs args;
7-
public Cropping1D ( CroppingArgs args ) : base(args) {
8-
this.args = args;
9-
}
5+
public class Cropping1D : Layer
6+
{
7+
CroppingArgs args;
8+
public Cropping1D(CroppingArgs args) : base(args)
9+
{
10+
this.args = args;
11+
}
1012

11-
protected override void build ( Tensors inputs ) {
12-
if ( args.cropping.rank != 1 ) {
13-
// throw an ValueError exception
14-
throw new ValueError("");
15-
}
16-
else if ( args.cropping.shape[0] > 2 || args.cropping.shape[0] < 1 ) {
17-
throw new ValueError("The `cropping` argument must be a tuple of 2 integers.");
18-
}
19-
built = true;
13+
public override void build(Shape input_shape)
14+
{
15+
if (args.cropping.rank != 1)
16+
{
17+
// throw an ValueError exception
18+
throw new ValueError("");
19+
}
20+
else if (args.cropping.shape[0] > 2 || args.cropping.shape[0] < 1)
21+
{
22+
throw new ValueError("The `cropping` argument must be a tuple of 2 integers.");
2023
}
24+
built = true;
25+
}
2126

22-
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {
23-
Tensor output = inputs;
24-
if ( output.rank != 3 ) {
25-
// throw an ValueError exception
26-
throw new ValueError("Expected dim=3, found dim=" + output.rank);
27-
}
28-
if ( args.cropping.shape[0] == 1 ) {
29-
int crop_start = args.cropping[0];
30-
output = output[new Slice(), new Slice(crop_start, ( int ) output.shape[1] - crop_start), new Slice()];
31-
}
32-
else {
33-
int crop_start = args.cropping[0], crop_end = args.cropping[1];
34-
output = output[new Slice(), new Slice(crop_start, ( int ) (output.shape[1]) - crop_end), new Slice()];
35-
}
36-
return output;
27+
protected override Tensors Call(Tensors inputs, Tensor state = null, bool? training = null)
28+
{
29+
Tensor output = inputs;
30+
if (output.rank != 3)
31+
{
32+
// throw an ValueError exception
33+
throw new ValueError("Expected dim=3, found dim=" + output.rank);
34+
}
35+
if (args.cropping.shape[0] == 1)
36+
{
37+
int crop_start = args.cropping[0];
38+
output = output[new Slice(), new Slice(crop_start, (int)output.shape[1] - crop_start), new Slice()];
3739
}
40+
else
41+
{
42+
int crop_start = args.cropping[0], crop_end = args.cropping[1];
43+
output = output[new Slice(), new Slice(crop_start, (int)(output.shape[1]) - crop_end), new Slice()];
44+
}
45+
return output;
46+
}
3847

39-
public override Shape ComputeOutputShape ( Shape input_shape ) {
40-
if ( args.cropping.shape[0] == 1 ) {
41-
int crop = args.cropping[0];
42-
return new Shape(( int ) (input_shape[0]), ( int ) (input_shape[1] - crop * 2), ( int ) (input_shape[2]));
43-
}
44-
else {
45-
int crop_start = args.cropping[0], crop_end = args.cropping[1];
46-
return new Shape(( int ) (input_shape[0]), ( int ) (input_shape[1] - crop_start - crop_end), ( int ) (input_shape[2]));
47-
}
48+
public override Shape ComputeOutputShape(Shape input_shape)
49+
{
50+
if (args.cropping.shape[0] == 1)
51+
{
52+
int crop = args.cropping[0];
53+
return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop * 2), (int)(input_shape[2]));
54+
}
55+
else
56+
{
57+
int crop_start = args.cropping[0], crop_end = args.cropping[1];
58+
return new Shape((int)(input_shape[0]), (int)(input_shape[1] - crop_start - crop_end), (int)(input_shape[2]));
4859
}
49-
}
60+
}
61+
}
5062
}

src/TensorFlowNET.Keras/Layers/Cropping/Cropping2D.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ public class Cropping2D : Layer {
1212
public Cropping2D ( Cropping2DArgs args ) : base(args) {
1313
this.args = args;
1414
}
15-
protected override void build ( Tensors inputs ) {
15+
public override void build(Shape input_shape) {
1616
built = true;
1717
}
1818
protected override Tensors Call ( Tensors inputs, Tensor state = null, bool? training = null ) {

src/TensorFlowNET.Keras/Layers/Cropping/Cropping3D.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ public Cropping3D ( Cropping3DArgs args ) : base(args) {
1111
this.args = args;
1212
}
1313

14-
protected override void build ( Tensors inputs ) {
14+
public override void build(Shape input_shape) {
1515
built = true;
1616
}
1717

src/TensorFlowNET.Keras/Layers/Merging/Concatenate.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ public Concatenate(MergeArgs args) : base(args)
2323
this.args = args;
2424
}
2525

26-
protected override void build(Tensors inputs)
26+
public override void build(Shape input_shape)
2727
{
2828
/*var shape_set = new HashSet<Shape>();
2929
var reduced_inputs_shapes = inputs.Select(x => x.shape).ToArray();

src/TensorFlowNET.Keras/Layers/Merging/Merge.cs

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ public Merge(MergeArgs args) : base(args)
1414

1515
}
1616

17-
protected override void build(Tensors inputs)
17+
public override void build(Shape input_shape)
1818
{
1919
// output_shape = input_shape.dims[1^];
2020
}

src/TensorFlowNET.Keras/Layers/Normalization/BatchNormalization.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -53,9 +53,8 @@ public BatchNormalization(BatchNormalizationArgs args) : base(args)
5353
axis = args.Axis.dims.Select(x => (int)x).ToArray();
5454
}
5555

56-
protected override void build(Tensors inputs)
56+
public override void build(Shape input_shape)
5757
{
58-
Shape input_shape = inputs.shape;
5958
var ndims = input_shape.ndim;
6059
foreach (var (idx, x) in enumerate(axis))
6160
if (x < 0)

src/TensorFlowNET.Keras/Layers/Normalization/LayerNormalization.cs

+1-2
Original file line numberDiff line numberDiff line change
@@ -49,9 +49,8 @@ public LayerNormalization(LayerNormalizationArgs args) : base(args)
4949
axis = args.Axis.axis;
5050
}
5151

52-
protected override void build(Tensors inputs)
52+
public override void build(Shape input_shape)
5353
{
54-
Shape input_shape = inputs.shape;
5554
var ndims = input_shape.ndim;
5655
foreach (var (idx, x) in enumerate(axis))
5756
if (x < 0)

0 commit comments

Comments
 (0)