From 482899eab734f1b6f3a39ef52a4f9ae28e332ed5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E2=80=9CWanglongzhi2001=E2=80=9D?= <“583087864@qq.com”> Date: Sat, 22 Jul 2023 15:03:50 +0800 Subject: [PATCH] fix: revise np.amin, np.amax and add np.argmin --- .../NumPy/NumPy.Sorting.Searching.Counting.cs | 4 ++++ src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs | 4 ++-- src/TensorFlowNET.Core/Operations/math_ops.cs | 3 +++ 3 files changed, 9 insertions(+), 2 deletions(-) diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs index 5182d5726..4cad36e0b 100644 --- a/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs @@ -13,6 +13,10 @@ public partial class np public static NDArray argmax(NDArray a, Axis? axis = null) => new NDArray(math_ops.argmax(a, axis ?? 0)); + [AutoNumPy] + public static NDArray argmin(NDArray a, Axis? axis = null) + => new NDArray(math_ops.argmin(a, axis ?? 0)); + [AutoNumPy] public static NDArray argsort(NDArray a, Axis? axis = null) => new NDArray(sort_ops.argsort(a, axis: axis ?? -1)); diff --git a/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs b/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs index 5d86b1b39..bce16ec9f 100644 --- a/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs +++ b/src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs @@ -10,10 +10,10 @@ namespace Tensorflow.NumPy public partial class np { [AutoNumPy] - public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.arg_min(x, axis)); + public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.min(x, axis)); [AutoNumPy] - public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis)); + public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.max(x, axis)); [AutoNumPy] public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false) diff --git a/src/TensorFlowNET.Core/Operations/math_ops.cs b/src/TensorFlowNET.Core/Operations/math_ops.cs index 092137bf2..e77df702f 100644 --- a/src/TensorFlowNET.Core/Operations/math_ops.cs +++ b/src/TensorFlowNET.Core/Operations/math_ops.cs @@ -77,6 +77,9 @@ public static Tensor add_n(Tensor[] inputs, string name = null) public static Tensor argmax(Tensor input, Axis dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) => gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name); + public static Tensor argmin(Tensor input, Axis dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null) + => gen_math_ops.arg_min(input, dimension, output_type: output_type, name: name); + public static Tensor round(Tensor x, string name = null) { x = ops.convert_to_tensor(x, name: "x");