From f026963a7da0cf8444f05fd4abce8962b5848c62 Mon Sep 17 00:00:00 2001
From: lingbai-kong
Date: Wed, 5 Jul 2023 21:46:50 +0800
Subject: [PATCH] add EinsumGrad

---
 src/TensorFlowNET.Core/Gradients/math_grad.cs | 131 ++++++++++++++++++
 1 file changed, 131 insertions(+)

diff --git a/src/TensorFlowNET.Core/Gradients/math_grad.cs b/src/TensorFlowNET.Core/Gradients/math_grad.cs
index be1fbbba7..8c3f0f8bd 100644
--- a/src/TensorFlowNET.Core/Gradients/math_grad.cs
+++ b/src/TensorFlowNET.Core/Gradients/math_grad.cs
@@ -117,6 +117,137 @@ public static Tensor[] _DivNoNanGrad(Operation op, Tensor[] grads)
             };
         }
 
+        public static string ellipsis = "...";
+        [RegisterGradient("Einsum")]
+        public static Tensor[] _EinsumGrad(Operation op, Tensor[] grads)
+        {
+            // Gradient for Einsum.
+            string equation = (string)op.get_attr("equation");
+            string[] split_equation = equation.Split(new string[] { "->" }, StringSplitOptions.None);
+            var input_subs = split_equation[0];
+            var output_subs = split_equation[1];
+
+            if (op.inputs.Length == 1)
+            {
+                var input_shape = array_ops.shape(op.inputs[0]);
+                var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + ellipsis)));
+                if (reduced_label_set.Count == 0)
+                    return new Tensor[] { math_ops.einsum(string.Format("{0}->{1}", output_subs, input_subs), new Tensors(grads)) };
+                return new Tensor[] { _GetGradReduced(new Tensors(grads), output_subs, input_subs, input_shape, reduced_label_set) };
+            }
+
+            string[] split_input_subs = input_subs.Split(new string[] { "," }, StringSplitOptions.None);
+            var x_subs = split_input_subs[0];
+            var y_subs = split_input_subs[1];
+            // Add ellipsis for broadcasted dimensions if any operand does not have it.
+            // This is because the equation "...ij,jk->ik" may be valid if the 0th input's
+            // batch shape is empty, but the VJP equation "jk,ik->...ij" is not valid
+            // because only the output subscripts contain ellipsis.
+            if (output_subs.Contains(ellipsis))
+            {
+                if (!x_subs.Contains(ellipsis))
+                    x_subs += ellipsis;
+                if (!y_subs.Contains(ellipsis))
+                    y_subs += ellipsis;
+            }
+            // Obtain the gradients wrt the inputs x and y, without taking into account
+            // the unbroadcasting.
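+            // For example, for equation "ij,jk->ik" (a plain matmul), _GetGradWrt
+            // below yields grad_x = einsum("ik,jk->ij", grad, y) and
+            // grad_y = einsum("ik,ij->jk", grad, x).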
+            var x = op.inputs[0];
+            var y = op.inputs[1];
+            if (grads.GetDataType().is_complex())
+            {
+                x = math_ops.conj(x);
+                y = math_ops.conj(y);
+            }
+
+            var x_shape = array_ops.shape(x);
+            var y_shape = array_ops.shape(y);
+            var grad_x = _GetGradWrt(grads, y, x_shape, x_subs, y_subs, output_subs);
+            var grad_y = _GetGradWrt(grads, x, y_shape, y_subs, x_subs, output_subs);
+
+            if (!output_subs.Contains(ellipsis))
+                return new Tensor[] { grad_x, grad_y };
+
+            var bx = _GetBcastSubshape(x_subs);
+            int bx_start = bx[0], bx_end = bx[1];
+            var by = _GetBcastSubshape(y_subs);
+            int by_start = by[0], by_end = by[1];
+
+            // If the static batch shapes are fully defined and equal, no
+            // unbroadcasting is needed.
+            var x_shape_static = x.shape;
+            var y_shape_static = y.shape;
+            if (x_shape_static.IsFullyDefined &&
+                y_shape_static.IsFullyDefined &&
+                x_shape_static[string.Format("{0}:{1}", bx_start, bx_end)] == y_shape_static[string.Format("{0}:{1}", by_start, by_end)])
+                return new Tensor[] { grad_x, grad_y };
+
+            // Otherwise, sum the gradient over the broadcasted batch axes and
+            // reshape back to the original input shapes.
+            var r = gen_array_ops.broadcast_gradient_args(x_shape[string.Format("{0}:{1}", bx_start, bx_end)],
+                                                          y_shape[string.Format("{0}:{1}", by_start, by_end)]);
+            var rx = r[0];
+            var ry = r[1];
+            grad_x = array_ops.reshape(math_ops.reduce_sum(grad_x, bx_start + rx), x_shape);
+            grad_y = array_ops.reshape(math_ops.reduce_sum(grad_y, by_start + ry), y_shape);
+            return new Tensor[] { grad_x, grad_y };
+        }
+
+        protected static Tensor _GetGradWrt(Tensor[] output_grads, Tensor other_operand, Tensor input_shape,
+            string input_subs, string other_subs, string output_subs)
+        {
+            // Labels of this input that appear neither in the output nor in the
+            // other operand were reduced away and must be restored afterwards.
+            var reduced_label_set = new HashSet<char>(new HashSet<char>(input_subs).Except(new HashSet<char>(output_subs + other_subs + ".")));
+            var left_subs = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));
+            var grad_reduced = math_ops.einsum(string.Format("{0},{1}->{2}", output_subs, other_subs, left_subs), new Tensors((Tensors)output_grads, other_operand));
+            if (reduced_label_set.Count == 0)
+                return grad_reduced;
+            return _GetGradReduced(grad_reduced, left_subs, input_subs, input_shape, reduced_label_set);
+        }
+
+        protected static Tensor _GetGradReduced(Tensor output_grad, string output_subs, string input_subs, Tensor input_shape, HashSet<char> reduced_label_set)
+        {
+            string reduced_subs;
+            Tensor reduced_dims;
+            List<int> reduced_axes;
+            _GetReducedSubscripts(reduced_label_set, input_shape, input_subs, out reduced_subs, out reduced_dims, out reduced_axes);
+            bool has_repeated_labels = (
+                new HashSet<char>(input_subs).Count + new HashSet<char>(output_subs).Count <
+                input_subs.Length + output_subs.Length);
+            var input_subs_without_reduced_labels = string.Join("", input_subs.Where(s => !reduced_label_set.Contains(s)));
+
+            if (!has_repeated_labels && input_subs_without_reduced_labels == output_subs)
+            {
+                // No repeated or transposed labels: the gradient is a reshape to
+                // the reduced shape followed by a broadcast back to the input shape.
+                var reduced_shape = math_ops.reduced_shape(input_shape, ops.convert_to_tensor(reduced_axes));
+                return gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), input_shape);
+            }
+            else
+            {
+                var grad_shape_with_reduced_labels = array_ops.concat(new Tensor[] { reduced_dims, array_ops.shape(new Tensors(output_grad)) }, axis: 0);
+                var reduced_shape = array_ops.concat(new Tensor[] { array_ops.ones(reduced_label_set.Count, dtype: dtypes.int32), array_ops.shape(new Tensors(output_grad)) }, axis: 0);
+                var broadcasted_grad = gen_array_ops.broadcast_to(array_ops.reshape(output_grad, reduced_shape), grad_shape_with_reduced_labels);
+                return math_ops.einsum(string.Format("{0}->{1}", reduced_subs + output_subs, input_subs), new Tensors(broadcasted_grad));
+            }
+        }
+
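+        /// <summary>
+        /// Collects the subscripts, axes and dimension sizes of the reduced
+        /// labels. E.g. for subscripts "abc" with reduced labels { 'b', 'c' },
+        /// this yields reduced_subs "bc", reduced_axes [1, 2] and
+        /// reduced_dims [shape[1], shape[2]].
+        /// </summary>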
+        protected static void _GetReducedSubscripts(HashSet<char> reduced_label_set, Tensor input_shape, string subscripts, out string reduced_subs, out Tensor reduced_dims, out List<int> reduced_axes)
+        {
+            reduced_subs = string.Join("", reduced_label_set.Select(c => c.ToString()));
+            reduced_axes = reduced_subs.Select(s => _GetAxisFromLabel(subscripts, s)).ToList();
+            reduced_dims = array_ops.stack(reduced_axes.Select(ax => input_shape[ax]).ToList());
+        }
+
+        /// <summary>
+        /// Returns the axis of the given label within the subscripts, searching
+        /// on either side of a possible ellipsis.
+        /// </summary>
+        protected static int _GetAxisFromLabel(string subscripts, char label)
+        {
+            var splits = subscripts.Split(new string[] { ellipsis }, StringSplitOptions.None);
+            var index = splits[0].IndexOf(label);
+            if (index != -1)
+                return index;
+            if (splits.Length < 2)
+                throw new OutOfRangeError();
+            index = splits[1].IndexOf(label);
+            if (index != -1)
+                return index;
+            throw new ValueError();
+        }
+
+        /// <summary>
+        /// Returns the start and end of the subshape covered by the ellipsis,
+        /// i.e. the broadcasted batch dimensions.
+        /// </summary>
+        protected static int[] _GetBcastSubshape(string subscripts)
+        {
+            int start = subscripts.IndexOf(ellipsis);
+            if (start == -1)
+                return new int[] { 0, 0 };
+            int remaining = subscripts.Length - (start + ellipsis.Length);
+            int end;
+            if (remaining > 0)
+                end = remaining;
+            else
+                throw new Exception();
+            return new int[] { start, end };
+        }
+
         /// <summary>
         /// Returns grad * exp(x).
         /// </summary>
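
A minimal usage sketch, assuming the usual TF.NET surface (tf.linalg.einsum
forwarding to math_ops.einsum, tf.GradientTape, and a file-scoped
using static Tensorflow.Binding;):

    var x = tf.ones((2, 3));
    var y = tf.ones((3, 4));
    using var tape = tf.GradientTape(persistent: true);
    tape.watch(x);
    tape.watch(y);
    var z = tf.linalg.einsum("ij,jk->ik", new Tensors(x, y));  // plain matmul
    var dx = tape.gradient(z, x);  // computed via the VJP "ik,jk->ij"
    var dy = tape.gradient(z, y);  // computed via the VJP "ik,ij->jk"

With the "Einsum" gradient registered above, both tape.gradient calls resolve
through _EinsumGrad; dx has shape (2, 3) and dy has shape (3, 4).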