Skip to content

Commit 15763df

Browse files
authored
Merge pull request #1186 from Beacontownfc/ragged
Improve RaggedTensor
2 parents eb4c1f4 + 02bfb9a commit 15763df

File tree

4 files changed

+127
-0
lines changed

4 files changed

+127
-0
lines changed

src/TensorFlowNET.Core/Operations/array_ops.cs

+13
Original file line numberDiff line numberDiff line change
@@ -1139,5 +1139,18 @@ public static Tensor placeholder(TF_DataType dtype, Shape shape = null, string n
11391139
var _op = tf.OpDefLib._apply_op_helper("Placeholder", name: name, args: new { dtype, shape });
11401140
return _op.output;
11411141
}
1142+
1143+
/// <summary>
/// Validates an axis index against a dimension count and returns its
/// non-negative equivalent.
/// </summary>
/// <param name="axis">Axis index. Negative values count from the end and are only accepted when <paramref name="ndims"/> is known.</param>
/// <param name="ndims">
/// Number of dimensions. Any negative value means "not statically known":
/// this covers the -100 default and also a rank of -1
/// (NOTE(review): presumably what Shape.rank reports for an unknown rank — confirm).
/// </param>
/// <param name="axis_name">Name used for the axis in error messages.</param>
/// <param name="ndims_name">Name used for the dimension count in error messages.</param>
/// <returns>The axis as a non-negative index.</returns>
/// <exception cref="ValueError">
/// The axis is outside [-ndims, ndims), or it is negative while ndims is unknown.
/// </exception>
public static int get_positive_axis(int axis, int ndims = -100, string axis_name = "axis", string ndims_name = "ndims")
{
    // A negative dimension count can never be a real rank, so treat every
    // negative value (not just the -100 sentinel) as "unknown".
    if (ndims >= 0)
    {
        if (axis >= 0 && axis < ndims)
        {
            return axis;
        }

        if (-ndims <= axis && axis < 0)
        {
            return axis + ndims;
        }

        throw new ValueError($"{axis_name}={axis} out of bounds:expected {-ndims}<={axis_name}<{ndims}");
    }

    // Without a known rank a negative axis cannot be resolved.
    if (axis < 0)
    {
        throw new ValueError($"{axis_name}={axis} may only be negative if {ndims_name} is statically known.");
    }

    return axis;
}
1154+
11421155
}
11431156
}

src/TensorFlowNET.Core/Tensors/Ragged/RaggedTensor.cs

+33
Original file line numberDiff line numberDiff line change
@@ -163,5 +163,38 @@ public static implicit operator RaggedTensor(Tensor tensor)
163163
{
164164
return tensor.Tag as RaggedTensor;
165165
}
166+
/// <summary>
/// Returns the number of rows of this ragged tensor, cast to the requested type.
/// </summary>
/// <param name="out_type">Data type of the returned tensor.</param>
/// <param name="name">Optional name scope; "RaggedNRows" is used as the default scope name.</param>
/// <returns>The row-partition's row count cast to <paramref name="out_type"/>.</returns>
public Tensor nrows(TF_DataType out_type, string name = null)
{
    // The value computed inside the name scope is the method's result, so it
    // must be propagated out of tf_with rather than dropped.
    return tf_with(ops.name_scope(name, "RaggedNRows"), scope =>
    {
        return math_ops.cast(this._row_partition.nrows(), dtype: out_type);
    });
}
174+
/// <summary>
/// Returns the row lengths of this ragged tensor along the given axis.
/// </summary>
/// <param name="axis">
/// Axis whose lengths are requested. 0 and 1 are answered directly by the row
/// partition; any other value is first normalized with
/// array_ops.get_positive_axis against this.shape.rank.
/// </param>
/// <param name="name">NOTE(review): not read anywhere in this method — confirm whether a name scope was intended.</param>
/// <returns>
/// The lengths for the requested axis.
/// NOTE(review): several branches produce a Tensor that is converted to
/// RaggedTensor on return; verify the conversion yields a usable value.
/// </returns>
public RaggedTensor row_lengths(int axis = -1, string name = null)
{
    // Axes 0 and 1 are served straight from the row partition.
    switch (axis)
    {
        case 0:
            return this._row_partition.nrows();
        case 1:
            return this._row_partition.row_lengths();
    }

    var inner_values = (RaggedTensor)this._values;

    // Resolve a negative axis against the rank of this tensor.
    axis = array_ops.get_positive_axis(
        axis, this.shape.rank, ndims_name: "rank(this)");

    if (axis == 0)
    {
        return this.nrows(this._row_partition.GetDataType());
    }

    if (axis == 1)
    {
        // Length of each row is the difference between adjacent split offsets.
        var row_splits = this._row_partition.row_splits;
        return row_splits[new Slice(start: 1)] - row_splits[new Slice(stop: -1)];
    }

    if (this._values is RaggedTensor)
    {
        // Deeper axes are delegated to the nested ragged values.
        return inner_values.row_lengths(axis - 1);
    }

    // Remaining case: every entry along the axis has the same length, taken
    // from dimension (axis - 1) of the values' shape.
    var dims = array_ops.shape(inner_values, out_type: this._row_partition.GetDataType());
    var ones = array_ops.ones(dims[new Slice(stop: axis - 1)], this._row_partition.GetDataType());
    return ones * dims[axis - 1];
}
166199
}
167200
}

src/TensorFlowNET.Core/Tensors/Ragged/RowPartition.cs

+55
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,15 @@ You may obtain a copy of the License at
1414
limitations under the License.
1515
******************************************************************************/
1616

17+
using Serilog.Debugging;
1718
using System;
19+
using System.Collections.Concurrent;
1820
using System.Collections.Generic;
21+
//using System.ComponentModel.DataAnnotations;
1922
using System.Text;
23+
using System.Xml.Linq;
2024
using Tensorflow.Framework;
25+
using Tensorflow.NumPy;
2126
using static Tensorflow.Binding;
2227

2328
namespace Tensorflow
@@ -99,5 +104,55 @@ public static RowPartition from_row_splits(Tensor row_splits,
99104
return new RowPartition(row_splits);
100105
});
101106
}
107+
108+
/// <summary>
/// Builds a RowPartition from a vector of per-row lengths.
/// </summary>
/// <param name="row_lengths">Row lengths; must have dtype int32 or int64 (checked by _convert_row_partition).</param>
/// <param name="validate">NOTE(review): not read by this method.</param>
/// <param name="dtype">Forwarded to _convert_row_partition.</param>
/// <param name="dtype_hint">Forwarded to _convert_row_partition.</param>
/// <returns>A RowPartition holding both the derived row splits and the given row lengths.</returns>
public static RowPartition from_row_lengths(Tensor row_lengths,
    bool validate = true,
    TF_DataType dtype = TF_DataType.TF_INT32,
    TF_DataType dtype_hint = TF_DataType.TF_INT32)
{
    row_lengths = _convert_row_partition(
        row_lengths, "row_lengths", dtype_hint: dtype_hint, dtype: dtype);

    // Running totals of the lengths give the end offset of each row.
    Tensor row_limits = math_ops.cumsum<Tensor>(row_lengths, tf.constant(-1));

    // The leading 0 must share the dtype of the limits (int32 or int64),
    // otherwise the concat below mixes element types.
    Tensor leading_zero = math_ops.cast(
        tf.convert_to_tensor(np.array(new int[] { 0 }, TF_DataType.TF_INT64)),
        dtype: row_limits.dtype);

    Tensor row_splits = array_ops.concat(new Tensor[] { leading_zero, row_limits }, axis: 0);
    return new RowPartition(row_splits: row_splits, row_lengths: row_lengths);
}
119+
120+
/// <summary>
/// Normalizes a row-partitioning tensor and checks that it is an integer tensor.
/// </summary>
/// <param name="partition">The partitioning tensor to check.</param>
/// <param name="name">Name used for the conversion and in the error message.</param>
/// <param name="dtype">NOTE(review): not read by this method.</param>
/// <param name="dtype_hint">NOTE(review): not read by this method.</param>
/// <returns>The (possibly converted) partition tensor.</returns>
/// <exception cref="ValueError">The partition's dtype is neither int32 nor int64.</exception>
public static Tensor _convert_row_partition(Tensor partition, string name, TF_DataType dtype,
    TF_DataType dtype_hint = TF_DataType.TF_INT64)
{
    // int32 NDArray inputs are passed through convert_to_tensor first.
    if (partition is NDArray && partition.GetDataType() == np.int32)
    {
        partition = ops.convert_to_tensor(partition, name: name);
    }

    bool is_supported_dtype =
        partition.GetDataType() == np.int32 || partition.GetDataType() == np.int64;
    if (!is_supported_dtype)
    {
        throw new ValueError($"{name} must have dtype int32 or int64");
    }

    return partition;
}
127+
128+
/// <summary>
/// Returns the number of rows created by this RowPartition.
/// </summary>
/// <returns>
/// The stored row count when one is present; otherwise the size of the first
/// dimension of the row splits minus one, in the row splits' dtype.
/// </returns>
public Tensor nrows()
{
    if (this._nrows != null)
    {
        return this._nrows;
    }

    var split_count = tensor_shape.dimension_at_index(this._row_splits.shape, 0);
    if (split_count == null)
    {
        // No dimension object available: read the size from the tensor itself.
        return array_ops.shape(this._row_splits, out_type: this.row_splits.dtype)[0] - 1;
    }

    // Dimension available: emit the count as a constant.
    return constant_op.constant(split_count.value - 1, dtype: this.row_splits.dtype);
}
136+
137+
/// <summary>
/// Returns the length of each row described by this RowPartition.
/// </summary>
/// <returns>
/// The stored row lengths when present; otherwise the differences between
/// adjacent entries of the row splits.
/// </returns>
public Tensor row_lengths()
{
    // Prefer the lengths supplied at construction time.
    if (this._row_lengths != null)
    {
        return this._row_lengths;
    }

    // Row i spans [splits[i], splits[i + 1]), so its length is the difference
    // between each split and its predecessor.
    var splits = this._row_splits;
    return splits[new Slice(start: 1)] - splits[new Slice(stop: -1)];
}
102157
}
103158
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
using System;
2+
using System.Collections.Generic;
3+
using System.Linq;
4+
using System.Text;
5+
using System.Threading.Tasks;
6+
using Microsoft.VisualStudio.TestTools.UnitTesting;
7+
using Tensorflow;
8+
using Tensorflow.NumPy;
9+
using static Tensorflow.Binding;
10+
11+
namespace TensorFlowNET.UnitTest.ManagedAPI
12+
{
13+
/// <summary>
/// Tests for building a RowPartition from row lengths.
/// </summary>
[TestClass]
public class RaggedTensorTest : EagerModeTestBase
{
    /// <summary>
    /// A partition built from five row lengths must report five rows.
    /// </summary>
    [TestMethod]
    public void Test_from_row_lengths()
    {
        var row_lengths = tf.convert_to_tensor(np.array(new int[] { 2, 0, 3, 1, 1 }, TF_DataType.TF_INT64));
        var rp = RowPartition.from_row_lengths(row_lengths, validate: false);

        var rp_row_lengths = rp.row_lengths();
        var rp_nrows = rp.nrows();

        Assert.IsNotNull(rp_row_lengths);
        // Compare against the known row count instead of a second call to the same method.
        Assert.AreEqual(5L, rp_nrows.ToArray<long>()[0]);
    }
}
26+
}

0 commit comments

Comments
 (0)