Commit fac48f5

Merge pull request #1132 from lingbai-kong/EinsumGrad

fix: add EinsumGrad

2 parents adc90af + f026963

File tree

1 file changed (+131, −0)

src/TensorFlowNET.Core/Gradients/math_grad.cs (+131 lines)
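The added code closely follows the Python gradient for Einsum in TensorFlow's ops/math_grad.py; the block comments in the diff below are carried over from that implementation.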
@@ -117,6 +117,137 @@ public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
        };
    }

    public static string ellipsis = "...";

    [RegisterGradient("Einsum")]
    public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads)
    {
        // Gradient for Einsum.
        string equation = (string)op.get_attr("equation");
        string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None);
        var input_subs = split_equation[0];
        var output_subs = split_equation[1];

        if (op.inputs.Length == 1)
        {
            // Unary einsum: transpose the output gradient back to the input
            // layout, restoring any labels that were summed out.
            var input_shape = array_ops.shape(op.inputs[0]);
            var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis)));
            if (reduced_label_set.Count == 0)
                return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) };
            return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) };
        }

        string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None);
        var x_subs = split_input_subs[0];
        var y_subs = split_input_subs[1];
        // Add ellipsis for broadcasted dimensions if any operand does not have it.
        // This is because the equation "...ij,jk->ik" may be valid if the 0th input's
        // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
        // because only the output subscripts contain ellipsis.
        if (output_subs.Contains(ellipsis))
        {
            if (!x_subs.Contains(ellipsis))
                x_subs += ellipsis;
            if (!y_subs.Contains(ellipsis))
                y_subs += ellipsis;
        }
        // Obtain the gradients wrt the inputs x and y, without taking into account
        // the unbroadcasting.
        var x = op.inputs[0];
        var y = op.inputs[1];
        if (grads.GetDataType().is_complex())
        {
            x = math_ops.conj(x);
            y = math_ops.conj(y);
        }

        var x_shape = array_ops.shape(x);
        var y_shape = array_ops.shape(y);
        var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs);
        var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs);

        if (!output_subs.Contains(ellipsis))
            return new Tensor[] { grad_x, grad_y };

        // Otherwise, the gradients may still have to be unbroadcast over the
        // batch (ellipsis) dimensions.
        var bx = _GetBcastSubshape(x_subs);
        int bx_start = bx[0], bx_end = bx[1];
        var by = _GetBcastSubshape(y_subs);
        int by_start = by[0], by_end = by[1];

        // If the static batch shapes are fully known and equal, no
        // unbroadcasting is necessary.
        var x_shape_static = x.shape;
        var y_shape_static = y.shape;
        if (x_shape_static.IsFullyDefined &&
            y_shape_static.IsFullyDefined &&
            x_shape_static[string.Format("{0}:{1}", bx_start, bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)])
            return new Tensor[] { grad_x, grad_y };

        // Sum the gradients over the axes that were broadcast in the forward
        // pass, then reshape back to the operand shapes.
        var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)],
                                                      y_shape[string.Format("{0}:{1}", by_start, by_end)]);
        var rx = r[0];
        var ry = r[1];
        grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape);
        grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape);
        return new Tensor[] { grad_x, grad_y };
    }
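
    // A worked instance of the two _GetGradWrt calls above (illustration, not
    // part of the diff): for z = einsum("ij,jk->ik", x, y) with upstream
    // gradient dz,
    //   grad_x = einsum("ik,jk->ij", dz, y)   // i.e. matmul(dz, transpose(y))
    //   grad_y = einsum("ik,ij->jk", dz, x)   // i.e. matmul(transpose(x), dz)
    // which are the familiar matrix-multiplication gradients.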
    // Returns the gradient with respect to one operand of a binary einsum,
    // ignoring unbroadcasting: the operand's subscripts and the output's trade
    // places, and the output gradient is contracted with the other operand.
    protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape,
        string input_subs, string other_subs, string output_subs)
    {
        // Labels of this operand that appear in neither the output nor the
        // other operand were summed out in the forward pass; they cannot be
        // recovered by einsum alone and are restored by _GetGradReduced.
        var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + other_subs + ".")));
        var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));
        var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand));
        if (reduced_label_set.Count == 0)
            return grad_reduced;
        return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set);
    }
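
    // Illustration (not part of the diff): for "ij,jk->k", the gradient wrt x
    // first contracts grad_reduced = einsum("k,jk->j", dz, y); the label 'i'
    // was summed out, so _GetGradReduced broadcasts the result from [J] back
    // to the input shape [I, J].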
    // Returns the gradient when labels of the input were reduced (summed out)
    // in the forward pass, by broadcasting the output gradient back over the
    // reduced dimensions.
    protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set)
    {
        string reduced_subs;
        Tensor reduced_dims;
        List<int> reduced_axes;
        _GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes);
        bool has_repeated_labels = (
            new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count <
            input_subs.Length + output_subs.Length);
        var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));

        if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs)
        {
            // Fast path: no repeated labels and no transpose is needed, so a
            // reshape followed by a broadcast recovers the input gradient.
            var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes));
            return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape);
        }
        else
        {
            // General case: prepend the reduced dimensions, broadcast the
            // gradient over them, and let einsum rearrange it back into the
            // input layout.
            var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0);
            var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0);
            var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels);
            return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad));
        }
    }
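
    // Illustration (not part of the diff): for the unary equation "abc->ac",
    // the label 'b' is reduced. The fast path above reshapes the incoming
    // [A, C] gradient to [A, 1, C] and broadcasts it back to [A, B, C].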
    // Collects, for the reduced label set: the concatenated subscript string,
    // the corresponding axes in the input, and their dimension sizes.
    protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes)
    {
        reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString()));
        reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList();
        reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList());
    }
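
    // Illustration (not part of the diff): with subscripts "abc" and reduced
    // label set { 'b' }, this produces reduced_subs = "b", reduced_axes = [1]
    // and reduced_dims = stack(input_shape[1]).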
    // Returns the axis of a subscript label. Labels behind the ellipsis are
    // returned as negative axes (counted from the end) so that the result
    // stays valid for any number of broadcast dimensions.
    protected static int _GetAxisFromLabel(string subscripts, char label)
    {
        var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None);
        var index = splits[0].IndexOf(label);
        if (index != -1) return index;
        if (splits.Length < 2) throw new OutOfRangeError();
        index = splits[1].IndexOf(label);
        if (index != -1) return index - splits[1].Length;
        throw new ValueError();
    }
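
    // Illustration (not part of the diff): _GetAxisFromLabel("ab...cd", 'a')
    // returns 0, while _GetAxisFromLabel("ab...cd", 'c') returns -2, since
    // labels behind the ellipsis must remain valid for any batch rank.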
    // Returns the start/end pair that slices out the ellipsis (broadcast)
    // dimensions of the subscripts; end is negative, counted from the end.
    protected static int[] _GetBcastSubshape(string subscripts)
    {
        int start = subscripts.IndexOf(ellipsis);
        if (start == -1) return new int[] { 0, 0 };
        int remaining = subscripts.Length - (start + ellipsis.Length);
        int end;
        if (remaining > 0)
            end = -remaining;
        else
            // A trailing ellipsis would need an open-ended slice, which an
            // int end cannot express.
            throw new Exception();
        return new int[] { start, end };
    }

    /// <summary>
    /// Returns grad * exp(x).
    /// </summary>
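
With the gradient registered, einsum expressions recorded on a gradient tape become differentiable. The following is a minimal usage sketch, not part of the commit; it assumes the usual Tensorflow.Binding entry points (tf.constant, tf.GradientTape, tape.watch, tf.linalg.einsum, tape.gradient):

    using static Tensorflow.Binding;

    var x = tf.constant(new float[,] { { 1f, 2f }, { 3f, 4f } });
    var y = tf.constant(new float[,] { { 5f, 6f }, { 7f, 8f } });

    using var tape = tf.GradientTape();
    tape.watch(x);
    // Matrix multiplication written as an einsum; _EinsumGrad supplies its VJP.
    var z = tf.linalg.einsum("ij,jk->ik", new Tensors(x, y));
    // Backprop through Einsum: dx == einsum("ik,jk->ij", ones_like(z), y).
    var dx = tape.gradient(z, x);
    print(dx);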
