Commit 86eb48b

Merge pull request #1010 from AsakusaRinne/fix_mean_square_error_grad
Add shape deduce for mean square error grad.
2 parents: 6838a51 + fc12978

File tree

2 files changed: +29 -6 lines

src/TensorFlowNET.Core/Gradients/math_grad.cs

+1 -1
@@ -840,7 +840,7 @@ public static Tensor[] _PowGrad(Operation op, Tensor[] grads)
         /// <param name="x"></param>
         /// <param name="y"></param>
         /// <returns></returns>
-        private static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
+        public static (Tensor, Tensor, bool)[] SmartBroadcastGradientArgs(Tensor x, Tensor y, Tensor grad)
         {
             Tensor sx, sy;
             if (x.shape.IsFullyDefined &&
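
The only change in math_grad.cs is visibility: SmartBroadcastGradientArgs goes from private to public so that _SquaredDifferenceGrad in nn_grad.cs (next file) can reuse it. As a rough, self-contained illustration of the shape deduction it performs, here is a hypothetical sketch (not the library implementation): given two broadcast-compatible shapes, find the axes of the broadcast result that each input's gradient must be summed over to get back to that input's own shape.

using System;
using System.Collections.Generic;
using System.Linq;

// Hypothetical sketch only; the real helper returns a (shape tensor,
// reduction indices, must_reduce) tuple per input instead of raw axis lists.
static (int[] reduceX, int[] reduceY) BroadcastReduceAxes(int[] xShape, int[] yShape)
{
    int rank = Math.Max(xShape.Length, yShape.Length);
    // Left-pad the shorter shape with 1s, NumPy-style.
    int[] px = Enumerable.Repeat(1, rank - xShape.Length).Concat(xShape).ToArray();
    int[] py = Enumerable.Repeat(1, rank - yShape.Length).Concat(yShape).ToArray();

    var rx = new List<int>();
    var ry = new List<int>();
    for (int axis = 0; axis < rank; axis++)
    {
        int bcast = Math.Max(px[axis], py[axis]);   // size of the broadcast output on this axis
        if (px[axis] != bcast) rx.Add(axis);        // x was broadcast here -> sum its gradient out
        if (py[axis] != bcast) ry.Add(axis);        // y was broadcast here -> sum its gradient out
    }
    return (rx.ToArray(), ry.ToArray());
}

// x: (2, 3) vs y: (3,) -> x needs no reduction, y's gradient is summed over axis 0.
var (rx0, ry0) = BroadcastReduceAxes(new[] { 2, 3 }, new[] { 3 });
Console.WriteLine($"reduce x over [{string.Join(",", rx0)}]; reduce y over [{string.Join(",", ry0)}]");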

src/TensorFlowNET.Core/Gradients/nn_grad.cs

+28 -5
@@ -15,6 +15,7 @@ limitations under the License.
 ******************************************************************************/
 
 using System;
+using System.Diagnostics;
 using System.Linq;
 using Tensorflow.Operations;
 using static Tensorflow.Binding;
@@ -135,13 +136,35 @@ public static Tensor[] _SquaredDifferenceGrad(Operation op, Tensor[] grads)
         {
             Tensor x = op.inputs[0];
             Tensor y = op.inputs[1];
+            var grad = grads[0];
             var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype);
-            var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y);
-            return new Tensor[]
+            var x_grad = math_ops.scalar_mul(scale, grad) * (x - y);
+            if (math_grad._ShapesFullySpecifiedAndEqual(x, y, grad))
             {
-                x_grad,
-                -x_grad
-            };
+                return new Tensor[] { x_grad, -x_grad };
+            }
+            var broadcast_info = math_grad.SmartBroadcastGradientArgs(x, y, grad);
+            Debug.Assert(broadcast_info.Length == 2);
+            var (sx, rx, must_reduce_x) = broadcast_info[0];
+            var (sy, ry, must_reduce_y) = broadcast_info[1];
+            Tensor gx, gy;
+            if (must_reduce_x)
+            {
+                gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx);
+            }
+            else
+            {
+                gx = x_grad;
+            }
+            if (must_reduce_y)
+            {
+                gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy);
+            }
+            else
+            {
+                gy = -x_grad;
+            }
+            return new Tensor[] { gx, gy };
         }
 
         /// <summary>
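
For context, here is a minimal usage-style sketch (hypothetical, not part of this PR) of the behaviour the new code path is meant to produce: when y is broadcast against x, the gradient with respect to y comes back with y's own shape instead of the broadcast shape. It assumes TF.NET eager mode and that tf.constant, tf.GradientTape, tape.watch/gradient and gen_math_ops.squared_difference are available as used below.

using Tensorflow;
using static Tensorflow.Binding;

// x has shape (2, 3); y has shape (3,) and is broadcast over the first axis.
var x = tf.constant(new[,] { { 1f, 2f, 3f }, { 4f, 5f, 6f } });
var y = tf.constant(new[] { 0.5f, 0.5f, 0.5f });

using var tape = tf.GradientTape();
tape.watch(y);
// SquaredDifference is the op behind mean-square-error style losses,
// so its gradient (_SquaredDifferenceGrad above) is what gets exercised here.
var loss = tf.reduce_mean(gen_math_ops.squared_difference(x, y));
var dy = tape.gradient(loss, y);

// With the broadcast handling in place, dy is expected to have y's shape (3,)
// rather than the broadcast shape (2, 3).
print(dy.shape);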
