Skip to content

Commit ba8f0b0

Browse files
committed
add DepthwiseConv2D (深度可分离卷积, depthwise separable convolution)
1 parent 43c3705 commit ba8f0b0

File tree

9 files changed

+425
-1
lines changed

9 files changed

+425
-1
lines changed

src/TensorFlowNET.Core/Eager/EagerRunner.RecordGradient.cs

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,11 @@ BackwardFunction GetGradientFunction(string op_name,
8080
Tensor[] op_outputs)
8181
=> (out_grads, unneeded_gradients) =>
8282
{
83+
if(!ops.gradientFunctions.ContainsKey(op_name))
84+
{
85+
throw new Exception($"gradientFunctions not find op_name: {op_name}");
86+
}
87+
8388
if (ops.gradientFunctions[op_name] == null)
8489
return new Tensor[op_inputs.Length];
8590

src/TensorFlowNET.Core/Gradients/nn_grad.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -229,6 +229,37 @@ public static Tensor[] _Conv2DGrad(Operation op, Tensor[] grads)
229229
};
230230
}
231231

232+
/// <summary>
/// Gradient function for DepthwiseConv2dNative.
/// Produces the gradients with respect to the input and the depthwise filter.
/// </summary>
/// <param name="op">The forward DepthwiseConv2dNative operation.</param>
/// <param name="grads">Incoming gradients w.r.t. the op's outputs; only grads[0] is used.</param>
/// <returns>Two tensors: d(loss)/d(input) and d(loss)/d(filter), in input order.</returns>
[RegisterGradient("DepthwiseConv2dNative")]
public static Tensor[] _DepthwiseConv2DGrad(Operation op, Tensor[] grads)
{
    // Read the forward op's attributes back so the backprop ops are
    // configured identically to the forward pass.
    var dilations = op.get_attr_list<int>("dilations");
    var strides = op.get_attr_list<int>("strides");
    var padding = op.get_attr<string>("padding");
    var explicit_paddings = op.get_attr_list<int>("explicit_paddings");
    var data_format = op.get_attr<string>("data_format");
    // shape[0] is the runtime shape of the input, shape[1] that of the filter.
    var shape = gen_array_ops.shape_n(new Tensor[] { op.inputs[0], op.inputs[1] });

    return new Tensor[]
    {
        // Gradient w.r.t. the input tensor.
        gen_nn_ops.depthwise_conv2d_native_backprop_input(
            shape[0], op.inputs[1], grads[0],
            strides, padding, explicit_paddings,
            dilations: dilations,
            data_format: data_format),
        // Gradient w.r.t. the depthwise filter.
        gen_nn_ops.depthwise_conv2d_native_backprop_filter(op.inputs[0], shape[1], grads[0],
            strides, padding,
            dilations: dilations,
            explicit_paddings: explicit_paddings,
            data_format: data_format)
    };
}
262+
232263
[RegisterGradient("FusedBatchNorm")]
233264
public static Tensor[] _FusedBatchNormGrad(Operation op, Tensor[] grads)
234265
=> _BaseFusedBatchNormGrad(op, 0, grads);

src/TensorFlowNET.Core/Keras/Layers/ILayersApi.cs

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,19 @@ public ILayer Conv2D(int filters,
9595
bool use_bias = true,
9696
string kernel_initializer = "glorot_uniform",
9797
string bias_initializer = "zeros");
98+
/// <summary>
/// Depthwise separable 2D convolution layer: applies one convolution filter
/// per input channel, producing `in_channels * depth_multiplier` output channels.
/// </summary>
/// <param name="kernel_size">Spatial size of the depthwise kernel.</param>
/// <param name="strides">Strides of the convolution.</param>
/// <param name="padding">"valid" or "same".</param>
/// <param name="data_format">"channels_last" or "channels_first".</param>
/// <param name="dilation_rate">Dilation rate for dilated (atrous) convolution.</param>
/// <param name="depth_multiplier">Number of depthwise output channels per input channel.</param>
/// <param name="depthwise_initializer">Initializer for the depthwise kernel.</param>
// NOTE(review): use_bias defaults to false here, unlike Keras' DepthwiseConv2D
// which defaults it to true — confirm this is intentional.
public ILayer DepthwiseConv2D(Shape kernel_size = null,
    Shape strides = null,
    string padding = "valid",
    string data_format = null,
    Shape dilation_rate = null,
    int groups = 1,
    int depth_multiplier = 1,
    string activation = null,
    bool use_bias = false,
    string kernel_initializer = "glorot_uniform",
    string bias_initializer = "zeros",
    string depthwise_initializer = "glorot_uniform"
    );
98111

99112
public ILayer Dense(int units);
100113
public ILayer Dense(int units,

src/TensorFlowNET.Core/Tensors/tensor_util.cs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -249,6 +249,9 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
249249
case sbyte val:
250250
tensor_proto.IntVal.AddRange(new[] { (int)val });
251251
break;
252+
case byte val:
253+
tensor_proto.IntVal.AddRange(new[] { (int)val });
254+
break;
252255
case int val:
253256
tensor_proto.IntVal.AddRange(new[] { val });
254257
break;
@@ -262,7 +265,7 @@ public static TensorProto make_tensor_proto(object values, TF_DataType dtype = T
262265
tensor_proto.DoubleVal.AddRange(new[] { val });
263266
break;
264267
default:
265-
throw new Exception("make_tensor_proto Not Implemented");
268+
throw new Exception($"make_tensor_proto Not Implemented {values.GetType().Name}");
266269
}
267270
}
268271

Lines changed: 167 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,167 @@
1+
using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography;
using System.Text;
using Newtonsoft.Json;
using Tensorflow.Common.Types;
using Tensorflow.Keras.ArgsDefinition;
using Tensorflow.Keras.Saving;
using Tensorflow.Keras.Utils;
using Tensorflow.Operations;
12+
13+
namespace Tensorflow.Keras.Layers
14+
{
15+
/// <summary>
/// Arguments for the <see cref="DepthwiseConv2D"/> layer, extending the
/// regular Conv2D arguments with depthwise-specific options.
/// </summary>
public class DepthwiseConv2DArgs: Conv2DArgs
{
    /// <summary>
    /// depth_multiplier: The number of depthwise convolution output channels for
    /// each input channel. The total number of depthwise convolution output
    /// channels will be equal to `filters_in * depth_multiplier`.
    /// </summary>
    [JsonProperty("depth_multiplier")]
    public int DepthMultiplier { get; set; } = 1;

    // Initializer for the depthwise kernel; when null the layer falls back
    // to the inherited KernelInitializer.
    [JsonProperty("depthwise_initializer")]
    public IInitializer DepthwiseInitializer { get; set; }
}
28+
29+
/// <summary>
/// Depthwise 2D convolution: applies a single convolutional filter per input
/// channel, producing `in_channels * depth_multiplier` output channels.
/// </summary>
public class DepthwiseConv2D : Conv2D
{
    /// <summary>
    /// The number of depthwise convolution output channels for each input
    /// channel. The total number of depthwise convolution output channels
    /// will be equal to `filters_in * depth_multiplier`.
    /// </summary>
    int DepthMultiplier = 1;

    // Initializer for the depthwise kernel; when null, kernel_initializer is used.
    IInitializer DepthwiseInitializer;

    // Strides normalized to the full input rank (e.g. {1, sH, sW, 1} for NHWC).
    int[] strides;

    // Dilation rates normalized to the full input rank.
    int[] dilation_rate;

    // Maps the Keras data_format string to the op-level format string.
    string getDataFormat()
    {
        return data_format == "channels_first" ? "NCHW" : "NHWC";
    }

    // Counter used to generate unique default layer names.
    static int _id = 1;

    public DepthwiseConv2D(DepthwiseConv2DArgs args) : base(args)
    {
        args.Padding = args.Padding.ToUpper();

        // FIX: increment the counter so successive unnamed layers receive
        // distinct names (the original never incremented _id, so every
        // unnamed instance was called "DepthwiseConv2D_1").
        if (string.IsNullOrEmpty(args.Name))
            name = "DepthwiseConv2D_" + _id++;

        this.DepthMultiplier = args.DepthMultiplier;
        this.DepthwiseInitializer = args.DepthwiseInitializer;
    }

    /// <summary>
    /// Creates the depthwise kernel (and optional bias) weights once the
    /// input shape is known.
    /// </summary>
    public override void build(KerasShapesWrapper input_shape)
    {
        var shape = input_shape.ToSingleShape();

        // Locate the channel axis for the configured data format.
        int channel_axis = data_format == "channels_first" ? 1 : -1;
        var input_channel = channel_axis < 0 ?
            shape.dims[shape.ndim + channel_axis] :
            shape.dims[channel_axis];

        var arg = args as DepthwiseConv2DArgs;

        // Normalize strides to the input rank.
        if (arg.Strides.ndim != shape.ndim)
        {
            if (arg.Strides.ndim == 2)
            {
                this.strides = new int[] { 1, (int)arg.Strides[0], (int)arg.Strides[1], 1 };
            }
            else
            {
                this.strides = conv_utils.normalize_tuple(new int[] { (int)arg.Strides[0] }, shape.ndim, "strides");
            }
        }
        else
        {
            this.strides = arg.Strides.dims.Select(o => (int)o).ToArray();
        }

        // Normalize dilation rates to the input rank.
        if (arg.DilationRate.ndim != shape.ndim)
        {
            this.dilation_rate = conv_utils.normalize_tuple(new int[] { (int)arg.DilationRate[0] }, shape.ndim, "dilation_rate");
        }
        else
        {
            // FIX: the original left dilation_rate null on this branch, which
            // was then passed to depthwise_conv2d_native in Call.
            this.dilation_rate = arg.DilationRate.dims.Select(o => (int)o).ToArray();
        }

        // FIX: the original read shape[0] (the batch dimension) for
        // channels_first; the channel count was already computed above for
        // both formats as input_channel (identical to the original for
        // channels_last, where it read the last dimension).
        long channel_data = input_channel;

        // Depthwise kernel shape: (kH, kW, in_channels, depth_multiplier).
        var depthwise_kernel_shape = this.kernel_size.dims.concat(new long[] {
            channel_data,
            this.DepthMultiplier
        });

        this.kernel = this.add_weight(
            shape: depthwise_kernel_shape,
            initializer: this.DepthwiseInitializer != null ? this.DepthwiseInitializer : this.kernel_initializer,
            name: "depthwise_kernel",
            trainable: true,
            dtype: DType,
            regularizer: this.kernel_regularizer
        );

        // NOTE(review): the input spec always constrains axis -1, even for
        // channels_first — confirm against the channels_first code path.
        var axes = new Dictionary<int, int>();
        axes.Add(-1, (int)input_channel);
        inputSpec = new InputSpec(min_ndim: rank + 2, axes: axes);

        if (use_bias)
        {
            // FIX: a depthwise conv produces in_channels * depth_multiplier
            // output channels, so the bias must cover all of them
            // (identical to the original when DepthMultiplier == 1, the default).
            bias = add_weight(name: "bias",
                shape: ((int)(channel_data * this.DepthMultiplier)),
                initializer: bias_initializer,
                trainable: true,
                dtype: DType);
        }

        built = true;
        _buildInputShape = input_shape;
    }

    /// <summary>
    /// Runs the depthwise convolution, then the optional bias add and activation.
    /// </summary>
    protected override Tensors Call(Tensors inputs, Tensors state = null,
        bool? training = false, IOptionalArgs? optional_args = null)
    {
        Tensor outputs = gen_nn_ops.depthwise_conv2d_native(
            inputs,
            filter: this.kernel.AsTensor(),
            strides: this.strides,
            padding: this.padding,
            dilations: this.dilation_rate,
            data_format: this.getDataFormat(),
            name: name
        );

        if (use_bias)
        {
            if (data_format == "channels_first")
            {
                throw new NotImplementedException("call channels_first");
            }
            else
            {
                outputs = gen_nn_ops.bias_add(outputs, ops.convert_to_tensor(bias),
                    data_format: this.getDataFormat(), name: name);
            }
        }

        if (activation != null)
            outputs = activation.Apply(outputs);

        return outputs;
    }
}
167+
}

src/TensorFlowNET.Keras/Layers/LayersApi.cs

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -210,6 +210,38 @@ public ILayer Conv2D(int filters,
210210
Activation = keras.activations.GetActivationFromName(activation)
211211
});
212212

213+
/// <summary>
/// Depthwise separable 2D convolution layer: one filter per input channel,
/// yielding `in_channels * depth_multiplier` output channels.
/// </summary>
/// <param name="kernel_size">Spatial size of the depthwise kernel; defaults to (5, 5) when null.</param>
/// <param name="strides">Convolution strides; defaults to 1 when null.</param>
/// <param name="depth_multiplier">Depthwise output channels per input channel.</param>
/// <param name="depthwise_initializer">Initializer for the depthwise kernel;
/// falls back to <paramref name="kernel_initializer"/> when null.</param>
public ILayer DepthwiseConv2D(Shape kernel_size = null,
    Shape strides = null,
    string padding = "valid",
    string data_format = null,
    Shape dilation_rate = null,
    int groups = 1,
    int depth_multiplier = 1,
    string activation = null,
    bool use_bias = false,
    string kernel_initializer = "glorot_uniform",
    string bias_initializer = "zeros",
    string depthwise_initializer = "glorot_uniform"
    )
    => new DepthwiseConv2D(new DepthwiseConv2DArgs
    {
        Rank = 2,
        // Filters is unused by the depthwise kernel shape (DepthMultiplier
        // takes its role), so it is pinned to 1 here.
        Filters = 1,
        KernelSize = (kernel_size == null) ? (5, 5) : kernel_size,
        Strides = strides == null ? (1) : strides,
        Padding = padding,
        DepthMultiplier = depth_multiplier,
        DataFormat = data_format,
        DilationRate = dilation_rate == null ? (1) : dilation_rate,
        Groups = groups,
        UseBias = use_bias,
        KernelInitializer = GetInitializerByName(kernel_initializer),
        DepthwiseInitializer = GetInitializerByName(depthwise_initializer == null ? kernel_initializer : depthwise_initializer),
        BiasInitializer = GetInitializerByName(bias_initializer),
        Activation = keras.activations.GetActivationFromName(activation),
    });
243+
244+
213245
/// <summary>
214246
/// Transposed convolution layer (sometimes called Deconvolution).
215247
/// </summary>

test/TensorFlowNET.Keras.UnitTest/EagerModeTestBase.cs

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,40 @@ public bool Equal(float[] f1, float[] f2)
3333
return ret;
3434
}
3535

36+
37+
/// <summary>
/// Asserts that two int arrays have the same length and are element-wise equal.
/// </summary>
/// <param name="f1">Expected values.</param>
/// <param name="f2">Actual values.</param>
public void AssertArray(int[] f1, int[] f2)
{
    // FIX: compare lengths first — the original indexed f2 using f1's length
    // (IndexOutOfRange when f2 is shorter, silent pass when f1 is shorter)
    // and initialized ret = false, so two equal EMPTY arrays failed.
    bool ret = f1.Length == f2.Length;
    for (var i = 0; ret && i < f1.Length; i++)
    {
        ret = f1[i] == f2[i];
    }

    if (!ret)
    {
        Assert.Fail($"Array not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]");
    }
}
52+
53+
/// <summary>
/// Asserts that two float arrays have the same length and are element-wise
/// equal within a small absolute tolerance (1e-5).
/// </summary>
/// <param name="f1">Expected values.</param>
/// <param name="f2">Actual values.</param>
public void AssertArray(float[] f1, float[] f2)
{
    var tolerance = .00001f;
    // FIX: compare lengths first — the original indexed f2 using f1's length
    // (IndexOutOfRange when f2 is shorter, silent pass when f1 is shorter)
    // and initialized ret = false, so two equal EMPTY arrays failed.
    bool ret = f1.Length == f2.Length;
    for (var i = 0; ret && i < f1.Length; i++)
    {
        // Absolute-tolerance comparison; never use == on floats.
        ret = Math.Abs(f1[i] - f2[i]) <= tolerance;
    }

    if (!ret)
    {
        Assert.Fail($"Array float not Equal:[{string.Join(",", f1)}] [{string.Join(",", f2)}]");
    }
}
69+
3670
public bool Equal(double[] d1, double[] d2)
3771
{
3872
bool ret = false;

0 commit comments

Comments
 (0)