Skip to content

Commit 79a9363

Browse files
authored
Merge pull request #1011 from BalashovK/master
Added: complex, real, imag, angle
2 parents ccc556d + febba7b commit 79a9363

File tree

5 files changed

+242
-43
lines changed

5 files changed

+242
-43
lines changed

src/TensorFlowNET.Core/APIs/tf.math.cs

+15-3
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*****************************************************************************
2-
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
2+
Copyright 2023 The TensorFlow.NET Authors. All Rights Reserved.
33
44
Licensed under the Apache License, Version 2.0 (the "License");
55
you may not use this file except in compliance with the License.
@@ -57,7 +57,7 @@ public Tensor softplus(Tensor features, string name = null)
5757

5858
public Tensor tanh(Tensor x, string name = null)
5959
=> math_ops.tanh(x, name: name);
60-
60+
6161
/// <summary>
6262
/// Finds values and indices of the `k` largest entries for the last dimension.
6363
/// </summary>
@@ -93,6 +93,16 @@ public Tensor bincount(Tensor arr, Tensor weights = null,
9393
bool binary_output = false)
9494
=> math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength,
9595
dtype: dtype, name: name, axis: axis, binary_output: binary_output);
96+
97+
public Tensor real(Tensor x, string name = null)
98+
=> gen_ops.real(x, x.dtype.real_dtype(), name);
99+
public Tensor imag(Tensor x, string name = null)
100+
=> gen_ops.imag(x, x.dtype.real_dtype(), name);
101+
102+
public Tensor conj(Tensor x, string name = null)
103+
=> gen_ops.conj(x, name);
104+
public Tensor angle(Tensor x, string name = null)
105+
=> gen_ops.angle(x, x.dtype.real_dtype(), name);
96106
}
97107

98108
public Tensor abs(Tensor x, string name = null)
@@ -537,7 +547,7 @@ public Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims
537547
public Tensor reduce_sum(Tensor input, Axis? axis = null, Axis? reduction_indices = null,
538548
bool keepdims = false, string name = null)
539549
{
540-
if(keepdims)
550+
if (keepdims)
541551
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices), keepdims: keepdims, name: name);
542552
else
543553
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices));
@@ -585,5 +595,7 @@ public Tensor square(Tensor x, string name = null)
585595
=> gen_math_ops.square(x, name: name);
586596
public Tensor squared_difference(Tensor x, Tensor y, string name = null)
587597
=> gen_math_ops.squared_difference(x: x, y: y, name: name);
598+
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null,
599+
string name = null) => gen_ops.complex(real, imag, dtype, name);
588600
}
589601
}

src/TensorFlowNET.Core/Operations/gen_ops.cs

+20-38
Original file line numberDiff line numberDiff line change
@@ -730,12 +730,7 @@ public static (Tensor sampled_candidates, Tensor true_expected_count, Tensor sam
730730
/// </remarks>
731731
public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle")
732732
{
733-
var dict = new Dictionary<string, object>();
734-
dict["input"] = input;
735-
if (Tout.HasValue)
736-
dict["Tout"] = Tout.Value;
737-
var op = tf.OpDefLib._apply_op_helper("Angle", name: name, keywords: dict);
738-
return op.output;
733+
return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(input).SetAttributes(new { Tout = Tout }));
739734
}
740735

741736
/// <summary>
@@ -4976,15 +4971,14 @@ public static Tensor compare_and_bitpack(Tensor input, Tensor threshold, string
49764971
/// tf.complex(real, imag) ==&amp;gt; [[2.25 + 4.75j], [3.25 + 5.75j]]
49774972
/// </code>
49784973
/// </remarks>
4979-
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null, string name = "Complex")
4974+
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? a_Tout = null, string name = "Complex")
49804975
{
4981-
var dict = new Dictionary<string, object>();
4982-
dict["real"] = real;
4983-
dict["imag"] = imag;
4984-
if (Tout.HasValue)
4985-
dict["Tout"] = Tout.Value;
4986-
var op = tf.OpDefLib._apply_op_helper("Complex", name: name, keywords: dict);
4987-
return op.output;
4976+
TF_DataType Tin = real.GetDataType();
4977+
if (a_Tout is null)
4978+
{
4979+
a_Tout = (Tin == TF_DataType.TF_DOUBLE)? TF_DataType.TF_COMPLEX128: TF_DataType.TF_COMPLEX64;
4980+
}
4981+
return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(real, imag).SetAttributes(new { T=Tin, Tout=a_Tout }));
49884982
}
49894983

49904984
/// <summary>
@@ -5008,12 +5002,7 @@ public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null,
50085002
/// </remarks>
50095003
public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs")
50105004
{
5011-
var dict = new Dictionary<string, object>();
5012-
dict["x"] = x;
5013-
if (Tout.HasValue)
5014-
dict["Tout"] = Tout.Value;
5015-
var op = tf.OpDefLib._apply_op_helper("ComplexAbs", name: name, keywords: dict);
5016-
return op.output;
5005+
return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(x).SetAttributes(new { Tout = Tout }));
50175006
}
50185007

50195008
/// <summary>
@@ -5313,10 +5302,7 @@ public static Tensor configure_distributed_t_p_u(string embedding_config = null,
53135302
/// </remarks>
53145303
public static Tensor conj(Tensor input, string name = "Conj")
53155304
{
5316-
var dict = new Dictionary<string, object>();
5317-
dict["input"] = input;
5318-
var op = tf.OpDefLib._apply_op_helper("Conj", name: name, keywords: dict);
5319-
return op.output;
5305+
return tf.Context.ExecuteOp("Conj", name, new ExecuteOpArgs(new object[] { input }));
53205306
}
53215307

53225308
/// <summary>
@@ -13325,14 +13311,12 @@ public static Tensor igammac(Tensor a, Tensor x, string name = "Igammac")
1332513311
/// tf.imag(input) ==&amp;gt; [4.75, 5.75]
1332613312
/// </code>
1332713313
/// </remarks>
13328-
public static Tensor imag(Tensor input, TF_DataType? Tout = null, string name = "Imag")
13314+
public static Tensor imag(Tensor input, TF_DataType? a_Tout = null, string name = "Imag")
1332913315
{
13330-
var dict = new Dictionary<string, object>();
13331-
dict["input"] = input;
13332-
if (Tout.HasValue)
13333-
dict["Tout"] = Tout.Value;
13334-
var op = tf.OpDefLib._apply_op_helper("Imag", name: name, keywords: dict);
13335-
return op.output;
13316+
TF_DataType Tin = input.GetDataType();
13317+
return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));
13318+
13319+
// return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input }));
1333613320
}
1333713321

1333813322
/// <summary>
@@ -23863,14 +23847,12 @@ public static Tensor reader_serialize_state_v2(Tensor reader_handle, string name
2386323847
/// tf.real(input) ==&amp;gt; [-2.25, 3.25]
2386423848
/// </code>
2386523849
/// </remarks>
23866-
public static Tensor real(Tensor input, TF_DataType? Tout = null, string name = "Real")
23850+
public static Tensor real(Tensor input, TF_DataType? a_Tout = null, string name = "Real")
2386723851
{
23868-
var dict = new Dictionary<string, object>();
23869-
dict["input"] = input;
23870-
if (Tout.HasValue)
23871-
dict["Tout"] = Tout.Value;
23872-
var op = tf.OpDefLib._apply_op_helper("Real", name: name, keywords: dict);
23873-
return op.output;
23852+
TF_DataType Tin = input.GetDataType();
23853+
return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));
23854+
23855+
// return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input}));
2387423856
}
2387523857

2387623858
/// <summary>

src/TensorFlowNET.Core/Operations/math_ops.cs

+4-2
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
using System.Linq;
2121
using Tensorflow.Framework;
2222
using static Tensorflow.Binding;
23+
using Tensorflow.Operations;
2324

2425
namespace Tensorflow
2526
{
@@ -35,8 +36,9 @@ public static Tensor abs(Tensor x, string name = null)
3536
name = scope;
3637
x = ops.convert_to_tensor(x, name: "x");
3738
if (x.dtype.is_complex())
38-
throw new NotImplementedException("math_ops.abs for dtype.is_complex");
39-
//return gen_math_ops.complex_abs(x, Tout: x.dtype.real_dtype, name: name);
39+
{
40+
return gen_ops.complex_abs(x, Tout: x.dtype.real_dtype(), name: name);
41+
}
4042
return gen_math_ops._abs(x, name: name);
4143
});
4244
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,202 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using Tensorflow.NumPy;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using Tensorflow;
7+
using static Tensorflow.Binding;
8+
using Buffer = Tensorflow.Buffer;
9+
using TensorFlowNET.Keras.UnitTest;
10+
11+
namespace TensorFlowNET.UnitTest.Basics
12+
{
13+
[TestClass]
14+
public class ComplexTest : EagerModeTestBase
15+
{
16+
// Tests for Complex128
17+
18+
[TestMethod]
19+
public void complex128_basic()
20+
{
21+
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 };
22+
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 };
23+
24+
Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE);
25+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
26+
27+
Tensor t_complex = tf.complex(t_real, t_imag);
28+
29+
Tensor t_real_result = tf.math.real(t_complex);
30+
Tensor t_imag_result = tf.math.imag(t_complex);
31+
32+
NDArray n_real_result = t_real_result.numpy();
33+
NDArray n_imag_result = t_imag_result.numpy();
34+
35+
double[] d_real_result =n_real_result.ToArray<double>();
36+
double[] d_imag_result = n_imag_result.ToArray<double>();
37+
38+
Assert.IsTrue(base.Equal(d_real_result, d_real));
39+
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
40+
}
41+
[TestMethod]
42+
public void complex128_abs()
43+
{
44+
tf.enable_eager_execution();
45+
46+
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
47+
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };
48+
49+
double[] d_abs = new double[] { 5.0, 13.0, 17.0, 25.0 };
50+
51+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
52+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
53+
54+
Tensor t_complex = tf.complex(t_real, t_imag);
55+
56+
Tensor t_abs_result = tf.abs(t_complex);
57+
58+
double[] d_abs_result = t_abs_result.numpy().ToArray<double>();
59+
Assert.IsTrue(base.Equal(d_abs_result, d_abs));
60+
}
61+
[TestMethod]
62+
public void complex128_conj()
63+
{
64+
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
65+
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };
66+
67+
double[] d_real_expected = new double[] { -3.0, -5.0, 8.0, 7.0 };
68+
double[] d_imag_expected = new double[] { 4.0, -12.0, 15.0, -24.0 };
69+
70+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
71+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
72+
73+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);
74+
75+
Tensor t_result = tf.math.conj(t_complex);
76+
77+
NDArray n_real_result = tf.math.real(t_result).numpy();
78+
NDArray n_imag_result = tf.math.imag(t_result).numpy();
79+
80+
double[] d_real_result = n_real_result.ToArray<double>();
81+
double[] d_imag_result = n_imag_result.ToArray<double>();
82+
83+
Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
84+
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));
85+
}
86+
[TestMethod]
87+
public void complex128_angle()
88+
{
89+
double[] d_real = new double[] { 0.0, 1.0, -1.0, 0.0 };
90+
double[] d_imag = new double[] { 1.0, 0.0, -2.0, -3.0 };
91+
92+
double[] d_expected = new double[] { 1.5707963267948966, 0, -2.0344439357957027, -1.5707963267948966 };
93+
94+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
95+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
96+
97+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);
98+
99+
Tensor t_result = tf.math.angle(t_complex);
100+
101+
NDArray n_result = t_result.numpy();
102+
103+
double[] d_result = n_result.ToArray<double>();
104+
105+
Assert.IsTrue(base.Equal(d_result, d_expected));
106+
}
107+
108+
// Tests for Complex64
109+
[TestMethod]
110+
public void complex64_basic()
111+
{
112+
tf.init_scope();
113+
float[] d_real = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
114+
float[] d_imag = new float[] { -1.0f, -3.0f, 5.0f, 7.0f };
115+
116+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
117+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
118+
119+
Tensor t_complex = tf.complex(t_real, t_imag);
120+
121+
Tensor t_real_result = tf.math.real(t_complex);
122+
Tensor t_imag_result = tf.math.imag(t_complex);
123+
124+
// Convert the EagerTensors to NumPy arrays directly
125+
float[] d_real_result = t_real_result.numpy().ToArray<float>();
126+
float[] d_imag_result = t_imag_result.numpy().ToArray<float>();
127+
128+
Assert.IsTrue(base.Equal(d_real_result, d_real));
129+
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
130+
}
131+
[TestMethod]
132+
public void complex64_abs()
133+
{
134+
tf.enable_eager_execution();
135+
136+
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
137+
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f };
138+
139+
float[] d_abs = new float[] { 5.0f, 13.0f, 17.0f, 25.0f };
140+
141+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
142+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
143+
144+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
145+
146+
Tensor t_abs_result = tf.abs(t_complex);
147+
148+
NDArray n_abs_result = t_abs_result.numpy();
149+
150+
float[] d_abs_result = n_abs_result.ToArray<float>();
151+
Assert.IsTrue(base.Equal(d_abs_result, d_abs));
152+
153+
}
154+
[TestMethod]
155+
public void complex64_conj()
156+
{
157+
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
158+
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f };
159+
160+
float[] d_real_expected = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
161+
float[] d_imag_expected = new float[] { 4.0f, -12.0f, 15.0f, -24.0f };
162+
163+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
164+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
165+
166+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
167+
168+
Tensor t_result = tf.math.conj(t_complex);
169+
170+
NDArray n_real_result = tf.math.real(t_result).numpy();
171+
NDArray n_imag_result = tf.math.imag(t_result).numpy();
172+
173+
float[] d_real_result = n_real_result.ToArray<float>();
174+
float[] d_imag_result = n_imag_result.ToArray<float>();
175+
176+
Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
177+
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));
178+
179+
}
180+
[TestMethod]
181+
public void complex64_angle()
182+
{
183+
float[] d_real = new float[] { 0.0f, 1.0f, -1.0f, 0.0f };
184+
float[] d_imag = new float[] { 1.0f, 0.0f, -2.0f, -3.0f };
185+
186+
float[] d_expected = new float[] { 1.5707964f, 0f, -2.0344439f, -1.5707964f };
187+
188+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
189+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
190+
191+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
192+
193+
Tensor t_result = tf.math.angle(t_complex);
194+
195+
NDArray n_result = t_result.numpy();
196+
197+
float[] d_result = n_result.ToArray<float>();
198+
199+
Assert.IsTrue(base.Equal(d_result, d_expected));
200+
}
201+
}
202+
}

test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj

+1
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
<ItemGroup>
3838
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
39+
<ProjectReference Include="..\TensorFlowNET.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj" />
3940
</ItemGroup>
4041

4142
</Project>

0 commit comments

Comments
 (0)