Skip to content

Commit a5fdaf4

Browse files
authored
Merge pull request #1155 from Wanglongzhi2001/master
fix: revise np.amin, np.amax and add np.argmin
2 parents 5e60a13 + 482899e commit a5fdaf4

File tree

3 files changed

+9
-2
lines changed

3 files changed

+9
-2
lines changed

src/TensorFlowNET.Core/NumPy/NumPy.Sorting.Searching.Counting.cs

+4
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,10 @@ public partial class np
1313
public static NDArray argmax(NDArray a, Axis? axis = null)
1414
=> new NDArray(math_ops.argmax(a, axis ?? 0));
1515

16+
[AutoNumPy]
17+
public static NDArray argmin(NDArray a, Axis? axis = null)
18+
=> new NDArray(math_ops.argmin(a, axis ?? 0));
19+
1620
[AutoNumPy]
1721
public static NDArray argsort(NDArray a, Axis? axis = null)
1822
=> new NDArray(sort_ops.argsort(a, axis: axis ?? -1));

src/TensorFlowNET.Core/NumPy/NumPy.Statistics.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,10 @@ namespace Tensorflow.NumPy
1010
public partial class np
1111
{
1212
[AutoNumPy]
13-
public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.arg_min(x, axis));
13+
public static NDArray amin(NDArray x, int axis = 0) => new NDArray(tf.min(x, axis));
1414

1515
[AutoNumPy]
16-
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.math.argmax(x, axis));
16+
public static NDArray amax(NDArray x, int axis = 0) => new NDArray(tf.max(x, axis));
1717

1818
[AutoNumPy]
1919
public static NDArray average(NDArray a, int axis = -1, NDArray? weights = null, bool returned = false)

src/TensorFlowNET.Core/Operations/math_ops.cs

+3
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ public static Tensor add_n(Tensor[] inputs, string name = null)
7777
public static Tensor argmax(Tensor input, Axis dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
7878
=> gen_math_ops.arg_max(input, dimension, output_type: output_type, name: name);
7979

80+
public static Tensor argmin(Tensor input, Axis dimension, TF_DataType output_type = TF_DataType.TF_INT64, string name = null)
81+
=> gen_math_ops.arg_min(input, dimension, output_type: output_type, name: name);
82+
8083
public static Tensor round(Tensor x, string name = null)
8184
{
8285
x = ops.convert_to_tensor(x, name: "x");

0 commit comments

Comments
 (0)