Skip to content

Commit febba7b

Browse files
committed
Added support for Complex128 and unit tests for it.
1 parent 1a62163 commit febba7b

File tree

2 files changed

+93
-15
lines changed

2 files changed

+93
-15
lines changed

src/TensorFlowNET.Core/Operations/gen_ops.cs

+19-10
Original file line numberDiff line numberDiff line change
@@ -730,7 +730,7 @@ public static (Tensor sampled_candidates, Tensor true_expected_count, Tensor sam
730730
/// </remarks>
731731
public static Tensor angle(Tensor input, TF_DataType? Tout = null, string name = "Angle")
732732
{
733-
return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(new object[] { input }));
733+
return tf.Context.ExecuteOp("Angle", name, new ExecuteOpArgs(input).SetAttributes(new { Tout = Tout }));
734734
}
735735

736736
/// <summary>
@@ -4971,13 +4971,16 @@ public static Tensor compare_and_bitpack(Tensor input, Tensor threshold, string
49714971
/// tf.complex(real, imag) ==&amp;gt; [[2.25 + 4.75j], [3.25 + 5.75j]]
49724972
/// </code>
49734973
/// </remarks>
4974-
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null, string name = "Complex")
4974+
public static Tensor complex(Tensor real, Tensor imag, TF_DataType? a_Tout = null, string name = "Complex")
49754975
{
4976-
return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(new object[] { real, imag })); // sorry, cannot pass Tout, so it only works with complex64. complex128 is not supported yet
4976+
TF_DataType Tin = real.GetDataType();
4977+
if (a_Tout is null)
4978+
{
4979+
a_Tout = (Tin == TF_DataType.TF_DOUBLE)? TF_DataType.TF_COMPLEX128: TF_DataType.TF_COMPLEX64;
4980+
}
4981+
return tf.Context.ExecuteOp("Complex", name, new ExecuteOpArgs(real, imag).SetAttributes(new { T=Tin, Tout=a_Tout }));
49774982
}
49784983

4979-
4980-
49814984
/// <summary>
49824985
/// Computes the complex absolute value of a tensor.
49834986
/// </summary>
@@ -4999,7 +5002,7 @@ public static Tensor complex(Tensor real, Tensor imag, TF_DataType? Tout = null,
49995002
/// </remarks>
50005003
public static Tensor complex_abs(Tensor x, TF_DataType? Tout = null, string name = "ComplexAbs")
50015004
{
5002-
return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(new object[] { x }));
5005+
return tf.Context.ExecuteOp("ComplexAbs", name, new ExecuteOpArgs(x).SetAttributes(new { Tout = Tout }));
50035006
}
50045007

50055008
/// <summary>
@@ -13308,9 +13311,12 @@ public static Tensor igammac(Tensor a, Tensor x, string name = "Igammac")
1330813311
/// tf.imag(input) ==&amp;gt; [4.75, 5.75]
1330913312
/// </code>
1331013313
/// </remarks>
13311-
public static Tensor imag(Tensor input, TF_DataType? Tout = null, string name = "Imag")
13314+
public static Tensor imag(Tensor input, TF_DataType? a_Tout = null, string name = "Imag")
1331213315
{
13313-
return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input }));
13316+
TF_DataType Tin = input.GetDataType();
13317+
return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));
13318+
13319+
// return tf.Context.ExecuteOp("Imag", name, new ExecuteOpArgs(new object[] { input }));
1331413320
}
1331513321

1331613322
/// <summary>
@@ -23841,9 +23847,12 @@ public static Tensor reader_serialize_state_v2(Tensor reader_handle, string name
2384123847
/// tf.real(input) ==&amp;gt; [-2.25, 3.25]
2384223848
/// </code>
2384323849
/// </remarks>
23844-
public static Tensor real(Tensor input, TF_DataType? Tout = null, string name = "Real")
23850+
public static Tensor real(Tensor input, TF_DataType? a_Tout = null, string name = "Real")
2384523851
{
23846-
return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input}));
23852+
TF_DataType Tin = input.GetDataType();
23853+
return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(input).SetAttributes(new { T = Tin, Tout = a_Tout }));
23854+
23855+
// return tf.Context.ExecuteOp("Real", name, new ExecuteOpArgs(new object[] {input}));
2384723856
}
2384823857

2384923858
/// <summary>

test/TensorFlowNET.Graph.UnitTest/ComplexTest.cs

+74-5
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ namespace TensorFlowNET.UnitTest.Basics
1313
[TestClass]
1414
public class ComplexTest : EagerModeTestBase
1515
{
16-
[Ignore("Not working")]
16+
// Tests for Complex128
17+
1718
[TestMethod]
1819
public void complex128_basic()
1920
{
@@ -23,7 +24,7 @@ public void complex128_basic()
2324
Tensor t_real = tf.constant(d_real, dtype:TF_DataType.TF_DOUBLE);
2425
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
2526

26-
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);
27+
Tensor t_complex = tf.complex(t_real, t_imag);
2728

2829
Tensor t_real_result = tf.math.real(t_complex);
2930
Tensor t_imag_result = tf.math.imag(t_complex);
@@ -34,9 +35,77 @@ public void complex128_basic()
3435
double[] d_real_result =n_real_result.ToArray<double>();
3536
double[] d_imag_result = n_imag_result.ToArray<double>();
3637

37-
Assert.AreEqual(d_real_result, d_real);
38-
Assert.AreEqual(d_imag_result, d_imag);
38+
Assert.IsTrue(base.Equal(d_real_result, d_real));
39+
Assert.IsTrue(base.Equal(d_imag_result, d_imag));
40+
}
41+
[TestMethod]
42+
public void complex128_abs()
43+
{
44+
tf.enable_eager_execution();
45+
46+
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
47+
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };
48+
49+
double[] d_abs = new double[] { 5.0, 13.0, 17.0, 25.0 };
50+
51+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
52+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
53+
54+
Tensor t_complex = tf.complex(t_real, t_imag);
55+
56+
Tensor t_abs_result = tf.abs(t_complex);
57+
58+
double[] d_abs_result = t_abs_result.numpy().ToArray<double>();
59+
Assert.IsTrue(base.Equal(d_abs_result, d_abs));
60+
}
61+
[TestMethod]
62+
public void complex128_conj()
63+
{
64+
double[] d_real = new double[] { -3.0, -5.0, 8.0, 7.0 };
65+
double[] d_imag = new double[] { -4.0, 12.0, -15.0, 24.0 };
66+
67+
double[] d_real_expected = new double[] { -3.0, -5.0, 8.0, 7.0 };
68+
double[] d_imag_expected = new double[] { 4.0, -12.0, 15.0, -24.0 };
69+
70+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
71+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
72+
73+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);
74+
75+
Tensor t_result = tf.math.conj(t_complex);
76+
77+
NDArray n_real_result = tf.math.real(t_result).numpy();
78+
NDArray n_imag_result = tf.math.imag(t_result).numpy();
79+
80+
double[] d_real_result = n_real_result.ToArray<double>();
81+
double[] d_imag_result = n_imag_result.ToArray<double>();
82+
83+
Assert.IsTrue(base.Equal(d_real_result, d_real_expected));
84+
Assert.IsTrue(base.Equal(d_imag_result, d_imag_expected));
85+
}
86+
[TestMethod]
87+
public void complex128_angle()
88+
{
89+
double[] d_real = new double[] { 0.0, 1.0, -1.0, 0.0 };
90+
double[] d_imag = new double[] { 1.0, 0.0, -2.0, -3.0 };
91+
92+
double[] d_expected = new double[] { 1.5707963267948966, 0, -2.0344439357957027, -1.5707963267948966 };
93+
94+
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_DOUBLE);
95+
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_DOUBLE);
96+
97+
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX128);
98+
99+
Tensor t_result = tf.math.angle(t_complex);
100+
101+
NDArray n_result = t_result.numpy();
102+
103+
double[] d_result = n_result.ToArray<double>();
104+
105+
Assert.IsTrue(base.Equal(d_result, d_expected));
39106
}
107+
108+
// Tests for Complex64
40109
[TestMethod]
41110
public void complex64_basic()
42111
{
@@ -47,7 +116,7 @@ public void complex64_basic()
47116
Tensor t_real = tf.constant(d_real, dtype: TF_DataType.TF_FLOAT);
48117
Tensor t_imag = tf.constant(d_imag, dtype: TF_DataType.TF_FLOAT);
49118

50-
Tensor t_complex = tf.complex(t_real, t_imag, TF_DataType.TF_COMPLEX64);
119+
Tensor t_complex = tf.complex(t_real, t_imag);
51120

52121
Tensor t_real_result = tf.math.real(t_complex);
53122
Tensor t_imag_result = tf.math.imag(t_complex);

0 commit comments

Comments
 (0)