Skip to content

Commit f874d3a

Browse files
NiklasGustafssonEsther2013
authored andcommitted
random_uniform calls random_uniform_int when generating integer tensor.
Added doc comment to array_ops.gather(), and implemented using ExecuteOp() Elaborated unit tests for gather, added one for slice()
1 parent 454a55b commit f874d3a

File tree

6 files changed

+131
-41
lines changed

6 files changed

+131
-41
lines changed

src/TensorFlowNET.Core/APIs/tf.random.cs

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -69,7 +69,17 @@ public Tensor random_uniform(TensorShape shape,
6969
float maxval = 1,
7070
TF_DataType dtype = TF_DataType.TF_FLOAT,
7171
int? seed = null,
72-
string name = null) => random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
72+
string name = null)
73+
{
74+
if (dtype.is_integer())
75+
{
76+
return random_ops.random_uniform_int(shape, (int)minval, (int)maxval, dtype, seed, name);
77+
}
78+
else
79+
{
80+
return random_ops.random_uniform(shape, minval, maxval, dtype, seed, name);
81+
}
82+
}
7383

7484
public Tensor truncated_normal(TensorShape shape,
7585
float mean = 0.0f,

src/TensorFlowNET.Core/Operations/array_ops.cs

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,22 @@ public static Tensor concat(object[] values, int axis, string name = "concat")
843843
return gen_array_ops.concat_v2(values, axis, name: name);
844844
}
845845

846-
public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null, int axis = 0)
846+
/// <summary>
847+
/// Gather slices from `params` according to `indices`. `indices` must be an integer tensor of any dimension(often 1-D).
848+
/// </summary>
849+
/// <typeparam name="T1">Element type of the indexed tensor.</typeparam>
850+
/// <typeparam name="T2">Element type of the index tensor.</typeparam>
851+
/// <param name="params">The `Tensor` from which to gather values. Must be at least rank `axis + 1`.</param>
852+
/// <param name="indices">The index `Tensor`. Must be one of the following types: `int32`, `int64`. The values must be in range `[0, params.shape[axis])`.</param>
853+
/// <param name="name">A name for the operation (optional).</param>
854+
/// <param name="axis">
855+
/// A `Tensor`. Must be one of the following types: `int32`, `int64`.
856+
/// The `axis` in `params` to gather `indices` from.Must be greater than or equal to `batch_dims`.
857+
/// Defaults to the first non-batch dimension. Supports negative indexes.
858+
/// </param>
859+
/// <param name="batch_dims">An integer. The number of batch dimensions. Must be less than or equal to rank(indices).</param>
860+
/// <returns></returns>
861+
public static Tensor gather<T1, T2>(T1 @params, T2 indices, string name = null, int axis = 0, int batch_dims = 0)
847862
{
848863
if (axis != 0)
849864
return gen_array_ops.gather_v2(@params, indices, axis, name: name);
@@ -913,7 +928,7 @@ private static Tensor[] split_eager_fallback<Ta, Tv>(Ta axis, Tv value, int num_
913928
}
914929

915930
public static Tensor slice(Tensor input, Tensor[] begin, Tensor[] size, string name = null)
916-
=> gen_array_ops.slice(input, begin, size, name: name);
931+
=> gen_array_ops.slice(input, begin, size, name: name);
917932

918933
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
919934
=> gen_array_ops.slice(input, begin, size, name: name);
@@ -928,6 +943,7 @@ public static Tensor slice(Tensor input, Tensor begin, Tensor size, string name
928943
}
929944
});
930945

946+
931947
public static Tensor stack(object values, int axis = 0, string name = "stack")
932948
{
933949
if (axis == 0)

src/TensorFlowNET.Core/Operations/gen_array_ops.cs

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -117,28 +117,13 @@ public static Tensor expand_dims(Tensor input, int axis, string name = null)
117117
=> tf.Context.ExecuteOp("ExpandDims", name, new ExecuteOpArgs(input, axis)
118118
.SetAttributes(new { dim = axis }));
119119

120-
public static Tensor gather_v2<T1, T2>(T1 @params, T2 indices, int axis, string name = null)
120+
public static Tensor gather_v2<T1, T2>(T1 @params, T2 indices, int axis, int batch_dims = 0, string name = null)
121121
{
122-
if (tf.Context.executing_eagerly())
123-
{
124-
try
125-
{
126-
var results = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("GatherV2", name, @params, indices, axis, "batch_dims", 0)
127-
{
128-
ctx = tf.Context,
129-
device_name = tf.Context.DeviceName
130-
});
131-
return results[0];
132-
}
133-
catch (Exception exc)
134-
{
135-
return gather_v2_eager_fallback(@params, indices, axis, name, tf.Context);
136-
}
137-
}
138-
139-
var _op = tf.OpDefLib._apply_op_helper("GatherV2", name: name, new { @params, indices, axis });
140-
141-
return _op.outputs[0];
122+
var result = tf.Context.ExecuteOp("GatherV2", name, new ExecuteOpArgs(
123+
@params,
124+
indices,
125+
axis).SetAttributes(new { batch_dims }));
126+
return result [0];
142127
}
143128

144129
private static Tensor gather_v2_eager_fallback(object @params, object indices, int axis, string name, Context ctx)
@@ -380,6 +365,12 @@ private static Tensor slice_eager_fallback(Tensor inputs, Tensor[] begin, Tensor
380365

381366
public static Tensor slice<Tb, Ts>(Tensor input, Tb begin, Ts size, string name = null)
382367
{
368+
if (tf.executing_eagerly())
369+
{
370+
var outputs = tf.Runner.TFE_FastPathExecute(new FastPathOpExecInfo("Slice", name, input, begin, size));
371+
return outputs[0];
372+
}
373+
383374
var _op = tf.OpDefLib._apply_op_helper("Slice", name, new { input, begin, size });
384375
return _op.outputs[0];
385376
}

src/TensorFlowNET.Core/Operations/gen_random_ops.cs

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -43,18 +43,8 @@ public static Tensor random_standard_normal(Tensor shape, TF_DataType dtype = TF
4343
/// <param name="name"></param>
4444
/// <returns></returns>
4545
public static Tensor random_uniform_int(Tensor shape, Tensor minval, Tensor maxval, int? seed = 0, int? seed2 = 0, string name = null)
46-
{
47-
if (!seed.HasValue)
48-
seed = 0;
49-
if (!seed2.HasValue)
50-
seed2 = 0;
51-
52-
var _op = tf.OpDefLib._apply_op_helper("RandomUniformInt",
53-
name: name,
54-
args: new { shape, minval, maxval, seed, seed2 });
55-
56-
return _op.outputs[0];
57-
}
46+
=> tf.Context.ExecuteOp("RandomUniformInt", name, new ExecuteOpArgs(shape, minval, maxval)
47+
.SetAttributes(new { seed = seed ?? 0, seed2 = seed2 ?? 0 }));
5848

5949
/// <summary>
6050
/// Outputs random values from a uniform distribution.

src/TensorFlowNET.Core/Operations/random_ops.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,34 @@ public static Tensor random_uniform(int[] shape,
8181
});
8282
}
8383

84+
/// <summary>
85+
/// Outputs random values from a uniform distribution.
86+
/// </summary>
87+
/// <param name="shape"></param>
88+
/// <param name="minval"></param>
89+
/// <param name="maxval"></param>
90+
/// <param name="dtype">The type of the output</param>
91+
/// <param name="seed">Used to create a random seed for the distribution.</param>
92+
/// <param name="name">A name for the operation</param>
93+
/// <returns>A tensor of the specified shape filled with random uniform values.</returns>
94+
public static Tensor random_uniform_int(int[] shape,
95+
int minval = 0,
96+
int maxval = 1,
97+
TF_DataType dtype = TF_DataType.TF_FLOAT,
98+
int? seed = null,
99+
string name = null)
100+
{
101+
return tf_with(ops.name_scope(name, "random_uniform_int", new { shape, minval, maxval }), scope =>
102+
{
103+
name = scope;
104+
var (seed1, seed2) = random_seed.get_seed(seed);
105+
var tensorShape = tensor_util.shape_tensor(shape);
106+
var minTensor = ops.convert_to_tensor(minval, dtype: dtype, name: "min");
107+
var maxTensor = ops.convert_to_tensor(maxval, dtype: dtype, name: "max");
108+
return gen_random_ops.random_uniform_int(tensorShape, minTensor, maxTensor, seed: seed1, seed2: seed2);
109+
});
110+
}
111+
84112
public static Tensor random_uniform(Tensor shape,
85113
int minval = 0,
86114
Tensor maxval = null,
Lines changed: 60 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,55 @@
11
using Microsoft.VisualStudio.TestTools.UnitTesting;
22
using NumSharp;
3+
using NumSharp.Utilities;
34
using Tensorflow;
45
using static Tensorflow.Binding;
56

67
namespace TensorFlowNET.UnitTest.ManagedAPI
78
{
89
[TestClass]
910
public class ArrayOpsTest : EagerModeTestBase
10-
{
11+
{
1112
/// <summary>
12-
/// https://www.tensorflow.org/api_docs/python/tf/keras/layers/Embedding
13+
/// https://www.tensorflow.org/api_docs/python/tf/slice
14+
/// </summary>
15+
[TestMethod]
16+
public void Slice()
17+
{
18+
// Tests based on example code in TF documentation
19+
var input_array = tf.constant(np.array(new int[] { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }).reshape(3,2,3));
20+
var indices = tf.constant(np.array(new int[] { 0, 2 }));
21+
22+
var r1 = array_ops.slice(input_array, new int[] { 1, 0, 0 }, new int[] { 1, 1, 3 });
23+
Assert.AreEqual(new TensorShape(1,1,3), r1.shape);
24+
var r1np = r1.numpy();
25+
Assert.AreEqual(r1np[0, 0, 0], 3);
26+
Assert.AreEqual(r1np[0, 0, 1], 3);
27+
Assert.AreEqual(r1np[0, 0, 2], 3);
28+
29+
30+
var r2 = array_ops.slice(input_array, new int[] { 1, 0, 0 }, new int[] { 1, 2, 3 });
31+
Assert.AreEqual(new TensorShape(1, 2, 3), r2.shape);
32+
var r2np = r2.numpy();
33+
Assert.AreEqual(r2np[0, 0, 0], 3);
34+
Assert.AreEqual(r2np[0, 0, 1], 3);
35+
Assert.AreEqual(r2np[0, 0, 2], 3);
36+
Assert.AreEqual(r2np[0, 1, 0], 4);
37+
Assert.AreEqual(r2np[0, 1, 1], 4);
38+
Assert.AreEqual(r2np[0, 1, 2], 4);
39+
40+
var r3 = array_ops.slice(input_array, new int[] { 1, 0, 0 }, new int[] { 2, 1, 3 });
41+
Assert.AreEqual(new TensorShape(2, 1, 3), r3.shape);
42+
var r3np = r3.numpy();
43+
Assert.AreEqual(r3np[0, 0, 0], 3);
44+
Assert.AreEqual(r3np[0, 0, 1], 3);
45+
Assert.AreEqual(r3np[0, 0, 2], 3);
46+
Assert.AreEqual(r3np[1, 0, 0], 5);
47+
Assert.AreEqual(r3np[1, 0, 1], 5);
48+
Assert.AreEqual(r3np[1, 0, 2], 5);
49+
}
50+
51+
/// <summary>
52+
/// https://www.tensorflow.org/api_docs/python/tf/gather
1353
/// </summary>
1454
[TestMethod]
1555
public void Gather()
@@ -19,9 +59,24 @@ public void Gather()
1959

2060
var result = array_ops.gather(input_array, indices);
2161
Assert.AreEqual(new TensorShape(2, 4), result.shape);
22-
Assert.AreEqual(result.numpy()[0,0], 0.0f);
23-
Assert.AreEqual(result.numpy()[0,1], 1.0f);
24-
Assert.AreEqual(result.numpy()[1,3], 11.0f);
62+
Assert.AreEqual(result.numpy()[0, 0], 0.0f);
63+
Assert.AreEqual(result.numpy()[0, 1], 1.0f);
64+
Assert.AreEqual(result.numpy()[1, 3], 11.0f);
65+
66+
// Tests based on example code in Python doc string for tf.gather()
67+
68+
var p1 = tf.random.normal(new TensorShape(5, 6, 7, 8));
69+
var i1 = tf.random_uniform(new TensorShape(10, 11), maxval: 7, dtype: tf.int32);
70+
var r1 = tf.gather(p1, i1, axis:2);
71+
Assert.AreEqual(new TensorShape(5, 6, 10, 11, 8), r1.shape);
72+
73+
var p2 = tf.random.normal(new TensorShape(4,3));
74+
var i2 = tf.constant(new int[,] { { 0, 2} });
75+
var r2 = tf.gather(p2, i2, axis: 0);
76+
Assert.AreEqual(new TensorShape(1, 2, 3), r2.shape);
77+
78+
var r3 = tf.gather(p2, i2, axis: 1);
79+
Assert.AreEqual(new TensorShape(4,1,2), r3.shape);
2580
}
2681
}
2782
}

0 commit comments

Comments
 (0)