@@ -117,6 +117,137 @@ public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
117
117
} ;
118
118
}
119
119
120
// Ellipsis token used in einsum equations to mark broadcasted (batch) dimensions.
public static string ellipsis = "...";
[RegisterGradient("Einsum")]
public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads)
{
    // Gradient for Einsum.
    // Port of tensorflow/python/ops/math_grad.py::_EinsumGrad: handles the unary
    // ("ab->a") and binary ("ab,bc->ac") forms, including ellipsis ("...")
    // broadcasting over batch dimensions.
    string equation = (string)op.get_attr("equation");
    // Split "inputs->output"; assumes the equation always has an explicit "->".
    string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None);
    var input_subs = split_equation[0];
    var output_subs = split_equation[1];

    if (op.inputs.Length == 1)
    {
        // Unary case: the VJP is the einsum mapping output labels back to input
        // labels, re-expanding any labels the forward pass summed out.
        var input_shape = array_ops.shape(op.inputs[0]);
        // Labels present in the input but absent from the output (ignoring the
        // ellipsis dots) were reduced in the forward pass.
        var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis)));
        if (reduced_label_set.Count == 0)
            // Nothing was reduced: the gradient is just the transposed einsum.
            return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) };
        return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) };
    }

    // Binary case: split "x_subs,y_subs". Only two operands are supported here.
    string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None);
    var x_subs = split_input_subs[0];
    var y_subs = split_input_subs[1];
    // Add ellipsis for broadcasted dimensions if any operand does not have it.
    // This is because the equation "...ij,jk->ik" may be valid if the 0th input's
    // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
    // because only the output subscripts contain ellipsis.
    if (output_subs.Contains(ellipsis))
    {
        if (!x_subs.Contains(ellipsis))
            x_subs += ellipsis;
        if (!y_subs.Contains(ellipsis))
            y_subs += ellipsis;
    }
    // Obtain the gradients wrt the inputs x and y, without taking into account
    // the unbroadcasting.
    var x = op.inputs[0];
    var y = op.inputs[1];
    if (grads.GetDataType().is_complex())
    {
        // For complex dtypes the VJP uses the conjugate of each operand.
        x = math_ops.conj(x);
        y = math_ops.conj(y);
    }

    var x_shape = array_ops.shape(x);
    var y_shape = array_ops.shape(y);
    var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs);
    var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs);

    if (!output_subs.Contains(ellipsis))
        // No batch dimensions, so no unbroadcasting is needed.
        return new Tensor[] { grad_x, grad_y };

    // Unbroadcast over the ellipsis (batch) span of each operand.
    // _GetBcastSubshape yields {start, end} delimiting that span in the subscripts.
    var bx = _GetBcastSubshape(x_subs);
    int bx_start = bx[0], bx_end = bx[1];
    var by = _GetBcastSubshape(y_subs);
    int by_start = by[0], by_end = by[1];

    // Fast path: when both batch shapes are statically known and equal, nothing
    // was broadcast and the gradients can be returned as-is.
    var x_shape_static = x.shape;
    var y_shape_static = y.shape;
    if (x_shape_static.IsFullyDefined &&
        y_shape_static.IsFullyDefined &&
        x_shape_static[string.Format("{0}:{1}", bx_start, bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)])
        return new Tensor[] { grad_x, grad_y };

    // Sum each gradient over the axes that were broadcast (as reported by
    // broadcast_gradient_args) and reshape back to the operand's shape.
    var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)],
        y_shape[string.Format("{0}:{1}", by_start, by_end)]);
    var rx = r[0];
    var ry = r[1];
    grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape);
    grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape);
    return new Tensor[] { grad_x, grad_y };
}
protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape,
    string input_subs, string other_subs, string output_subs)
{
    // Gradient wrt one einsum operand, before any batch-dim unbroadcasting.
    // Labels of this operand that appear in neither the output nor the other
    // operand (nor as an ellipsis dot) were summed out in the forward pass and
    // must be re-expanded afterwards.
    var visible_labels = new HashSet<char>(output_subs + other_subs + ".");
    var reduced_labels = new HashSet<char>(new HashSet<char>(input_subs).Except(visible_labels));

    // Subscripts of this operand that survive in the VJP einsum output.
    var kept_subs = new string(input_subs.Where(c => !reduced_labels.Contains(c)).ToArray());

    // VJP: contract the output gradient with the other operand over kept labels.
    var grad_reduced = math_ops.einsum($"{output_subs},{other_subs}->{kept_subs}",
        new Tensors((Tensors)output_grads, other_operand));

    return reduced_labels.Count == 0
        ? grad_reduced
        : _GetGradReduced(grad_reduced, kept_subs, input_subs, input_shape, reduced_labels);
}
protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set)
{
    // Expands `output_grad` back to the einsum input's labels, re-introducing the
    // labels in `reduced_label_set` that the forward einsum summed out.
    string reduced_subs;
    Tensor reduced_dims;
    List<int> reduced_axes;
    _GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes);
    // Repeated labels (e.g. "aab") need the general einsum path below: a plain
    // broadcast cannot reproduce gradients that land on a diagonal.
    bool has_repeated_labels = (
        new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count <
        input_subs.Length + output_subs.Length);
    var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));

    if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs)
    {
        // Fast path: the forward reduction was a plain reduce_sum over
        // `reduced_axes`, so the gradient is the output grad reshaped to keep-dims
        // form and broadcast back to the input shape.
        var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes));
        return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape);
    }
    else
    {
        // General path: prepend the reduced dims (first as size-1 axes), broadcast
        // the gradient across them, then einsum back into the input's label order.
        var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0);
        var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0);
        var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels);
        return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad));
    }
}
protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes)
{
    // Computes the subscript string, axes, and dimension sizes of the reduced labels.
    // Materialize the reduced labels as a subscript string (set order preserved).
    reduced_subs = new string(reduced_label_set.ToArray());

    // Axis of each reduced label within `subscripts`.
    reduced_axes = new List<int>();
    foreach (var label in reduced_subs)
        reduced_axes.Add(_GetAxisFromLabel(subscripts, label));

    // Corresponding dimension sizes, gathered from the input shape and stacked.
    var dim_list = new List<Tensor>();
    foreach (var ax in reduced_axes)
        dim_list.Add(input_shape[ax]);
    reduced_dims = array_ops.stack(dim_list);
}
protected static int _GetAxisFromLabel(string subscripts, char label)
{
    // Returns the axis of `label` within `subscripts`.
    // Labels left of the ellipsis map to non-negative axes counted from the front;
    // labels right of the ellipsis map to NEGATIVE axes counted from the back, so
    // the axis stays correct however many dimensions the ellipsis absorbs.
    var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None);
    var index = splits[0].IndexOf(label);
    if (index != -1)
        return index;
    // Label not before the ellipsis; there must be an ellipsis to look behind.
    if (splits.Length < 2)
        throw new OutOfRangeError();
    index = splits[1].IndexOf(label);
    if (index != -1)
        // BUGFIX: axes after the ellipsis must be end-relative (negative), mirroring
        // tensorflow/python/ops/math_grad.py::_GetAxisFromLabel. Returning the raw
        // positive `index` would point at the wrong dimension whenever the ellipsis
        // absorbs one or more batch dims.
        return index - splits[1].Length;
    // Label appears nowhere in `subscripts`.
    throw new ValueError();
}
protected static int[] _GetBcastSubshape(string subscripts)
{
    // Returns {start, end} delimiting the ellipsis (broadcast) span of `subscripts`,
    // for use in a python-style "start:end" slice over the operand's shape.
    int start = subscripts.IndexOf(ellipsis);
    if (start == -1)
        return new int[] { 0, 0 }; // no ellipsis -> empty broadcast span
    int remaining = subscripts.Length - (start + ellipsis.Length);
    int end;
    if (remaining > 0)
        // BUGFIX: the slice end must be end-relative (negative), mirroring
        // tensorflow/python/ops/math_grad.py::_GetBcastSubshape, so the span covers
        // exactly the dims absorbed by the ellipsis. A positive `remaining` would be
        // a front-relative index and select the wrong dimensions.
        end = -remaining;
    else
        // Ellipsis at the very end of the subscripts (remaining == 0) would need an
        // open-ended slice ("start:"), which this int-pair representation cannot
        // express; keep the original failure behavior for that unsupported case.
        throw new Exception();
    return new int[] { start, end };
}
+
120
251
/// <summary>
121
252
/// Returns grad * exp(x).
122
253
/// </summary>
0 commit comments