@@ -15,6 +15,7 @@ limitations under the License.
 ******************************************************************************/
 
 using System;
+using System.Diagnostics;
 using System.Linq;
 using Tensorflow.Operations;
 using static Tensorflow.Binding;
@@ -135,13 +136,35 @@ public static Tensor[] _SquaredDifferenceGrad(Operation op, Tensor[] grads)
         {
             Tensor x = op.inputs[0];
             Tensor y = op.inputs[1];
+            var grad = grads[0];
             var scale = ops.convert_to_tensor(2.0f, dtype: x.dtype);
-            var x_grad = math_ops.scalar_mul(scale, grads[0]) * (x - y);
-            return new Tensor[]
+            var x_grad = math_ops.scalar_mul(scale, grad) * (x - y);
+            if (math_grad._ShapesFullySpecifiedAndEqual(x, y, grad))
             {
-                x_grad,
-                -x_grad
-            };
+                return new Tensor[] { x_grad, -x_grad };
+            }
+            var broadcast_info = math_grad.SmartBroadcastGradientArgs(x, y, grad);
+            Debug.Assert(broadcast_info.Length == 2);
+            var (sx, rx, must_reduce_x) = broadcast_info[0];
+            var (sy, ry, must_reduce_y) = broadcast_info[1];
+            Tensor gx, gy;
+            if (must_reduce_x)
+            {
+                gx = array_ops.reshape(math_ops.reduce_sum(x_grad, rx), sx);
+            }
+            else
+            {
+                gx = x_grad;
+            }
+            if (must_reduce_y)
+            {
+                gy = -array_ops.reshape(math_ops.reduce_sum(x_grad, ry), sy);
+            }
+            else
+            {
+                gy = -x_grad;
+            }
+            return new Tensor[] { gx, gy };
         }
 
         /// <summary>
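
For reference, a rough sketch of the math the new branch implements (g denotes the incoming gradient grads[0]; the broadcast handling is the standard reduce-and-reshape rule, not anything specific to this change):

    z = (x - y)^2
    \partial z / \partial x =  2 (x - y) g
    \partial z / \partial y = -2 (x - y) g

When x and y were broadcast against each other, each raw gradient is summed over the broadcast axes (rx, ry) and reshaped back to the corresponding input shape (sx, sy); that is what the reduce_sum/reshape pair guarded by must_reduce_x and must_reduce_y does.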