fix: add EinsumGrad #1132

Merged
merged 1 commit into from Jul 6, 2023
131 changes: 131 additions & 0 deletions src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -117,6 +117,137 @@ public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
};
}

public static string ellipsis = "...";
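/// <summary>
/// Gradient for the Einsum op. Follows the reference Python implementation
/// of _EinsumGrad in tensorflow/python/ops/math_grad.py.
/// </summary>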
[RegisterGradient("Einsum")]
public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads)
{
// Gradient for Einsum.
string equation = (string)op.get_attr("equation");
string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None);
var input_subs = split_equation[0];
var output_subs = split_equation[1];

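// Unary case, e.g. "ab->b" or "aa->a": either transpose the gradient back
// with the reversed equation, or restore the reduced/repeated labels.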
if (op.inputs.Length == 1)
{
var input_shape = array_ops.shape(op.inputs[0]);
var reduced_label_set = new HashSet<char>(input_subs.Except(output_subs + ellipsis));
if (reduced_label_set.Count == 0)
return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) };
return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) };
}

string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None);
var x_subs = split_input_subs[0];
var y_subs = split_input_subs[1];
// Add ellipsis for broadcasted dimensions if any operand does not have it.
// This is because the equation "...ij,jk->ik" may be valid if the 0th input's
// batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
// because only the output subscripts contain ellipsis.
if (output_subs.Contains(ellipsis))
{
if (!x_subs.Contains(ellipsis))
x_subs += ellipsis;
if (!y_subs.Contains(ellipsis))
y_subs += ellipsis;
}
// Obtain the gradients wrt the inputs x and y, without taking into account
// the unbroadcasting.
var x = op.inputs[0];
var y = op.inputs[1];
if (grads.GetDataType().is_complex())
{
x = math_ops.conj(x);
y = math_ops.conj(y);
}

var x_shape = array_ops.shape(x);
var y_shape = array_ops.shape(y);
var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs);
var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs);

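// If the output has no ellipsis there was no broadcasting, so no
// unbroadcasting is needed.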
if (!output_subs.Contains(ellipsis))
return new Tensor[] { grad_x, grad_y };
var bx = _GetBcastSubshape(x_subs);
int bx_start = bx[0], bx_end = bx[1];
var by = _GetBcastSubshape(y_subs);
int by_start = by[0], by_end = by[1];

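// If the static broadcast shapes of both inputs are known and equal, the
// gradients already have the correct shape.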
var x_shape_static = x.shape;
var y_shape_static = y.shape;
if (x_shape_static.IsFullyDefined &&
y_shape_static.IsFullyDefined &&
x_shape_static[string.Format("{0}:{1}", bx_start, bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)])
return new Tensor[] { grad_x, grad_y };

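// Otherwise, sum the gradients over the broadcasted axes and reshape them
// back to the respective input shapes.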
var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)],
y_shape[string.Format("{0}:{1}", by_start, by_end)]);
var rx = r[0];
var ry = r[1];
grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape);
grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape);
return new Tensor[] { grad_x, grad_y };
}
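/// <summary>
/// Returns the gradient with respect to the operand whose subscripts are
/// input_subs: contracts the output gradient with the other operand, then
/// restores any labels that the forward einsum reduced away.
/// </summary>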
protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape,
string input_subs, string other_subs, string output_subs)
{
var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + other_subs + ".")));
var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));
var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand));
if (reduced_label_set.Count == 0)
return grad_reduced;
return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set);
}
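/// <summary>
/// Expands output_grad back to the input shape for labels that the forward
/// einsum reduced over, by reshaping/broadcasting and, when repeated labels
/// or transpositions are involved, a final einsum.
/// </summary>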
protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set)
{
string reduced_subs;
Tensor reduced_dims;
List<int> reduced_axes;
_GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes);
bool has_repeated_labels = (
new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count <
input_subs.Length + output_subs.Length);
var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));

if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs)
{
var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes));
return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape);
}
else
{
var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0);
var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0);
var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels);
return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad));
}
}
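/// <summary>
/// Computes the subscript string, dimension sizes and axes corresponding to
/// the reduced labels.
/// </summary>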
protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes)
{
// Sort the labels for a deterministic ordering, as the Python reference does.
reduced_subs = string.Join("", reduced_label_set.OrderBy(c => c));
reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList();
reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList());
}
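/// <summary>
/// Returns the axis in subscripts for the given label; labels that appear
/// after an ellipsis yield negative axes, counted from the end of the shape.
/// </summary>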
protected static int _GetAxisFromLabel(string subscripts, char label)
{
var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None);
var index = splits[0].IndexOf(label);
if (index != -1) return index;
if (splits.Length < 2) throw new OutOfRangeError();
index = splits[1].IndexOf(label);
// Labels after the ellipsis address axes from the end of the shape, since
// the ellipsis stands for a variable number of dimensions.
if (index != -1) return index - splits[1].Length;
throw new ValueError($"Invalid subscripts {subscripts}.");
}
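/// <summary>
/// Returns the [start, end) span of the broadcasting (ellipsis) dimensions in
/// subscripts, where end is negative and counts from the end of the shape.
/// </summary>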
protected static int[] _GetBcastSubshape(string subscripts)
{
int start = subscripts.IndexOf(ellipsis);
if (start == -1) return new int[] { 0, 0 };
int remaining = subscripts.Length - (start + ellipsis.Length);
// The span ends `remaining` axes before the end of the shape, so the end
// index is negative. (The Python reference returns None, i.e. "slice to the
// end", when the ellipsis is the last token; that case is not supported here.)
if (remaining == 0) throw new NotSupportedException("Trailing ellipsis is not supported.");
return new int[] { start, -remaining };
}

/// <summary>
/// Returns grad * exp(x).
/// </summary>
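For context, here is a minimal sketch of how the new gradient is exercised through TF.NET's eager API. It is not part of the diff, and the tf.linalg.einsum / tf.GradientTape / print bindings are assumed from the library's public surface:

using static Tensorflow.Binding;

var a = tf.constant(new float[,] { { 1f, 2f }, { 3f, 4f } });
var b = tf.constant(new float[,] { { 5f, 6f }, { 7f, 8f } });

using var tape = tf.GradientTape();
tape.watch(a); // record operations on the constant input
var y = tf.linalg.einsum("ij,jk->ik", new Tensors(a, b)); // forward Einsum (matmul)
var grad = tape.gradient(tf.reduce_sum(y), a); // backward pass dispatches to _EinsumGrad
// d(sum(y))/da contracts the upstream ones with b: each row of grad should be
// the row sums of b, i.e. [[11, 15], [11, 15]].
print(grad);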