Skip to content

Commit f9409ed

Browse files
committed
Added: complex, real, imag, angle
1 parent 86eb48b commit f9409ed

File tree

6 files changed

+184
-39
lines changed

6 files changed

+184
-39
lines changed
Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
/*****************************************************************************
2+
Copyright 2018 The TensorFlow.NET Authors. All Rights Reserved.
3+
4+
Licensed under the Apache License, Version 2.0 (the "License");
5+
you may not use this file except in compliance with the License.
6+
You may obtain a copy of the License at
7+
8+
http://www.apache.org/licenses/LICENSE-2.0
9+
10+
Unless required by applicable law or agreed to in writing, software
11+
distributed under the License is distributed on an "AS IS" BASIS,
12+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
See the License for the specific language governing permissions and
14+
limitations under the License.
15+
******************************************************************************/
16+
17+
using Tensorflow.Operations;
18+
19+
namespace Tensorflow
20+
{
21+
public partial class tensorflow
22+
{
23+
public Tensor complex(Tensor real, Tensor imag, Tensorflow.TF_DataType? dtype = null,
24+
string name = null) => gen_ops.complex(real, imag, dtype, name);
25+
}
26+
}

src/TensorFlowNET.Core/APIs/tf.math.cs

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ public Tensor softplus(Tensor features, string name = null)
5757

5858
public Tensor tanh(Tensor x, string name = null)
5959
=> math_ops.tanh(x, name: name);
60-
60+
6161
/// <summary>
6262
/// Finds values and indices of the `k` largest entries for the last dimension.
6363
/// </summary>
@@ -93,6 +93,16 @@ public Tensor bincount(Tensor arr, Tensor weights = null,
9393
bool binary_output = false)
9494
=> math_ops.bincount(arr, weights: weights, minlength: minlength, maxlength: maxlength,
9595
dtype: dtype, name: name, axis: axis, binary_output: binary_output);
96+
97+
public Tensor real(Tensor x, string name = null)
98+
=> gen_ops.real(x, x.dtype.real_dtype(), name);
99+
public Tensor imag(Tensor x, string name = null)
100+
=> gen_ops.imag(x, x.dtype.real_dtype(), name);
101+
102+
public Tensor conj(Tensor x, string name = null)
103+
=> gen_ops.conj(x, name);
104+
public Tensor angle(Tensor x, string name = null)
105+
=> gen_ops.angle(x, x.dtype.real_dtype(), name);
96106
}
97107

98108
public Tensor abs(Tensor x, string name = null)
@@ -537,7 +547,7 @@ public Tensor reduce_prod(Tensor input_tensor, Axis? axis = null, bool keepdims
537547
public Tensor reduce_sum(Tensor input, Axis? axis = null, Axis? reduction_indices = null,
538548
bool keepdims = false, string name = null)
539549
{
540-
if(keepdims)
550+
if (keepdims)
541551
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices), keepdims: keepdims, name: name);
542552
else
543553
return math_ops.reduce_sum(input, axis: constant_op.constant(axis ?? reduction_indices));

src/TensorFlowNET.Core/Operations/gen_ops.cs

Lines changed: 8 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -730,12 +730,7 @@ public static (Tensor sampled_candidates, Tensor true_expected_count, Tensor sam
730730
/// </remarks>
731731
public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle")
732732
{
733-
var dict = new Dictionary<string, object>();
734-
dict["input"] = input;
735-
if (Tout.HasValue)
736-
dict["Tout"] = Tout.Value;
737-
var op = tf.OpDefLib._apply_op_helper("Angle", name: name, keywords: dict);
738-
return op.output;
733+
return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(new object[] { input }));
739734
}
740735

741736
/// <summary>
@@ -4978,15 +4973,11 @@ public static Tensor compare_and_bitpack(Tensor input, Tensor threshold, string
49784973
/// </remarks>
49794974
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null, string name = "Complex")
49804975
{
4981-
var dict = new Dictionary<string, object>();
4982-
dict["real"] = real;
4983-
dict["imag"] = imag;
4984-
if (Tout.HasValue)
4985-
dict["Tout"] = Tout.Value;
4986-
var op = tf.OpDefLib._apply_op_helper("Complex", name: name, keywords: dict);
4987-
return op.output;
4976+
return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(new object[] { real, imag })); // sorry, cannot pass Tout, so it only works with complex64. complex128 is not supported yet
49884977
}
49894978

4979+
4980+
49904981
/// <summary>
49914982
/// Computes the complex absolute value of a tensor.
49924983
/// </summary>
@@ -5008,12 +4999,7 @@ public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null,
50084999
/// </remarks>
50095000
public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs")
50105001
{
5011-
var dict = new Dictionary<string, object>();
5012-
dict["x"] = x;
5013-
if (Tout.HasValue)
5014-
dict["Tout"] = Tout.Value;
5015-
var op = tf.OpDefLib._apply_op_helper("ComplexAbs", name: name, keywords: dict);
5016-
return op.output;
5002+
return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(new object[] { x }));
50175003
}
50185004

50195005
/// <summary>
@@ -5313,10 +5299,7 @@ public static Tensor configure_distributed_t_p_u(string embedding_config = null,
53135299
/// </remarks>
53145300
public static Tensor conj(Tensor input, string name = "Conj")
53155301
{
5316-
var dict = new Dictionary<string, object>();
5317-
dict["input"] = input;
5318-
var op = tf.OpDefLib._apply_op_helper("Conj", name: name, keywords: dict);
5319-
return op.output;
5302+
return tf.Context.ExecuteOp("Conj", name, new ExecuteOpArgs(new object[] { input }));
53205303
}
53215304

53225305
/// <summary>
@@ -13327,12 +13310,7 @@ public static Tensor igammac(Tensor a, Tensor x, string name = "Igammac")
1332713310
/// </remarks>
1332813311
public static Tensor imag(Tensor input, TF_DataType? Tout = null, string name = "Imag")
1332913312
{
13330-
var dict = new Dictionary<string, object>();
13331-
dict["input"] = input;
13332-
if (Tout.HasValue)
13333-
dict["Tout"] = Tout.Value;
13334-
var op = tf.OpDefLib._apply_op_helper("Imag", name: name, keywords: dict);
13335-
return op.output;
13313+
return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input }));
1333613314
}
1333713315

1333813316
/// <summary>
@@ -23865,12 +23843,7 @@ public static Tensor reader_serialize_state_v2(Tensor reader_handle, string name
2386523843
/// </remarks>
2386623844
public static Tensor real(Tensor input, TF_DataType? Tout = null, string name = "Real")
2386723845
{
23868-
var dict = new Dictionary<string, object>();
23869-
dict["input"] = input;
23870-
if (Tout.HasValue)
23871-
dict["Tout"] = Tout.Value;
23872-
var op = tf.OpDefLib._apply_op_helper("Real", name: name, keywords: dict);
23873-
return op.output;
23846+
return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input}));
2387423847
}
2387523848

2387623849
/// <summary>

src/TensorFlowNET.Core/Operations/math_ops.cs

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License.
2020
using System.Linq;
2121
using Tensorflow.Framework;
2222
using static Tensorflow.Binding;
23+
using Tensorflow.Operations;
2324

2425
namespace Tensorflow
2526
{
@@ -35,8 +36,9 @@ public static Tensor abs(Tensor x, string name = null)
3536
name = scope;
3637
x = ops.convert_to_tensor(x, name: "x");
3738
if (x.dtype.is_complex())
38-
throw new NotImplementedException("math_ops.abs for dtype.is_complex");
39-
//return gen_math_ops.complex_abs(x, Tout: x.dtype.real_dtype, name: name);
39+
{
40+
return gen_ops.complex_abs(x, Tout: x.dtype.real_dtype(), name: name);
41+
}
4042
return gen_math_ops._abs(x, name: name);
4143
});
4244
}
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
using Microsoft.VisualStudio.TestTools.UnitTesting;
2+
using Tensorflow.NumPy;
3+
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
6+
using Tensorflow;
7+
using static Tensorflow.Binding;
8+
using Buffer = Tensorflow.Buffer;
9+
using TensorFlowNET.Keras.UnitTest;
10+
11+
namespace TensorFlowNET.UnitTest.Basics
12+
{
13+
[TestClass]
14+
public class ComplexTest : EagerModeTestBase
15+
{
16+
[Ignore("Not working")]
17+
[TestMethod]
18+
public void complex128_basic()
19+
{
20+
double[] d_real = new double[] { 1.0, 2.0, 3.0, 4.0 };
21+
double[] d_imag = new double[] { -1.0, -3.0, 5.0, 7.0 };
22+
23+
Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE);
24+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
25+
26+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);
27+
28+
Tensor t_real_result = tf.math.real(t_complex);
29+
Tensor t_imag_result = tf.math.imag(t_complex);
30+
31+
NDArray n_real_result = t_real_result.numpy();
32+
NDArray n_imag_result = t_imag_result.numpy();
33+
34+
double[] d_real_result =n_real_result.ToArray<double>();
35+
double[] d_imag_result = n_imag_result.ToArray<double>();
36+
37+
Assert.AreEqual(d_real_result, d_real);
38+
Assert.AreEqual(d_imag_result, d_imag);
39+
}
40+
[TestMethod]
41+
public void complex64_basic()
42+
{
43+
tf.init_scope();
44+
float[] d_real = new float[] { 1.0f, 2.0f, 3.0f, 4.0f };
45+
float[] d_imag = new float[] { -1.0f, -3.0f, 5.0f, 7.0f };
46+
47+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
48+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
49+
50+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
51+
52+
Tensor t_real_result = tf.math.real(t_complex);
53+
Tensor t_imag_result = tf.math.imag(t_complex);
54+
55+
// Convert the EagerTensors to NumPy arrays directly
56+
float[] d_real_result = t_real_result.numpy().ToArray<float>();
57+
float[] d_imag_result = t_imag_result.numpy().ToArray<float>();
58+
59+
Assert.IsTrue(base.Equal(d_real_result, d_real));
60+
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
61+
}
62+
[TestMethod]
63+
public void complex64_abs()
64+
{
65+
tf.enable_eager_execution();
66+
67+
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
68+
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f };
69+
70+
float[] d_abs = new float[] { 5.0f, 13.0f, 17.0f, 25.0f };
71+
72+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
73+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
74+
75+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
76+
77+
Tensor t_abs_result = tf.abs(t_complex);
78+
79+
NDArray n_abs_result = t_abs_result.numpy();
80+
81+
float[] d_abs_result = n_abs_result.ToArray<float>();
82+
Assert.IsTrue(base.Equal(d_abs_result, d_abs));
83+
84+
}
85+
[TestMethod]
86+
public void complex64_conj()
87+
{
88+
float[] d_real = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
89+
float[] d_imag = new float[] { -4.0f, 12.0f, -15.0f, 24.0f };
90+
91+
float[] d_real_expected = new float[] { -3.0f, -5.0f, 8.0f, 7.0f };
92+
float[] d_imag_expected = new float[] { 4.0f, -12.0f, 15.0f, -24.0f };
93+
94+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
95+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
96+
97+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
98+
99+
Tensor t_result = tf.math.conj(t_complex);
100+
101+
NDArray n_real_result = tf.math.real(t_result).numpy();
102+
NDArray n_imag_result = tf.math.imag(t_result).numpy();
103+
104+
float[] d_real_result = n_real_result.ToArray<float>();
105+
float[] d_imag_result = n_imag_result.ToArray<float>();
106+
107+
Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
108+
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));
109+
110+
}
111+
[TestMethod]
112+
public void complex64_angle()
113+
{
114+
float[] d_real = new float[] { 0.0f, 1.0f, -1.0f, 0.0f };
115+
float[] d_imag = new float[] { 1.0f, 0.0f, -2.0f, -3.0f };
116+
117+
float[] d_expected = new float[] { 1.5707964f, 0f, -2.0344439f, -1.5707964f };
118+
119+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
120+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
121+
122+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
123+
124+
Tensor t_result = tf.math.angle(t_complex);
125+
126+
NDArray n_result = t_result.numpy();
127+
128+
float[] d_result = n_result.ToArray<float>();
129+
130+
Assert.IsTrue(base.Equal(d_result, d_expected));
131+
}
132+
}
133+
}

test/TensorFlowNET.Graph.UnitTest/TensorFlowNET.Graph.UnitTest.csproj

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@
3636

3737
<ItemGroup>
3838
<ProjectReference Include="..\..\src\TensorFlowNET.Core\Tensorflow.Binding.csproj" />
39+
<ProjectReference Include="..\TensorFlowNET.Keras.UnitTest\Tensorflow.Keras.UnitTest.csproj" />
3940
</ItemGroup>
4041

4142
</Project>

0 commit comments

Comments
 (0)