@@ -14,10 +14,15 @@ You may obtain a copy of the License at
14
14
limitations under the License.
15
15
******************************************************************************/
16
16
17
+ using Serilog . Debugging ;
17
18
using System ;
19
+ using System . Collections . Concurrent ;
18
20
using System . Collections . Generic ;
21
+ //using System.ComponentModel.DataAnnotations;
19
22
using System . Text ;
23
+ using System . Xml . Linq ;
20
24
using Tensorflow . Framework ;
25
+ using Tensorflow . NumPy ;
21
26
using static Tensorflow . Binding ;
22
27
23
28
namespace Tensorflow
@@ -99,5 +104,55 @@ public static RowPartition from_row_splits(Tensor row_splits,
99
104
return new RowPartition ( row_splits ) ;
100
105
} ) ;
101
106
}
107
+
108
+ public static RowPartition from_row_lengths ( Tensor row_lengths ,
109
+ bool validate = true ,
110
+ TF_DataType dtype = TF_DataType . TF_INT32 ,
111
+ TF_DataType dtype_hint = TF_DataType . TF_INT32 )
112
+ {
113
+ row_lengths = _convert_row_partition (
114
+ row_lengths , "row_lengths" , dtype_hint : dtype_hint , dtype : dtype ) ;
115
+ Tensor row_limits = math_ops . cumsum < Tensor > ( row_lengths , tf . constant ( - 1 ) ) ;
116
+ Tensor row_splits = array_ops . concat ( new Tensor [ ] { tf . convert_to_tensor ( np . array ( new int [ ] { 0 } , TF_DataType . TF_INT64 ) ) , row_limits } , axis : 0 ) ;
117
+ return new RowPartition ( row_splits : row_splits , row_lengths : row_lengths ) ;
118
+ }
119
+
120
+ public static Tensor _convert_row_partition ( Tensor partition , string name , TF_DataType dtype ,
121
+ TF_DataType dtype_hint = TF_DataType . TF_INT64 )
122
+ {
123
+ if ( partition is NDArray && partition . GetDataType ( ) == np . int32 ) partition = ops . convert_to_tensor ( partition , name : name ) ;
124
+ if ( partition . GetDataType ( ) != np . int32 && partition . GetDataType ( ) != np . int64 ) throw new ValueError ( $ "{ name } must have dtype int32 or int64") ;
125
+ return partition ;
126
+ }
127
+
128
+ public Tensor nrows ( )
129
+ {
130
+ /*Returns the number of rows created by this `RowPartition*/
131
+ if ( this . _nrows != null ) return this . _nrows ;
132
+ var nsplits = tensor_shape . dimension_at_index ( this . _row_splits . shape , 0 ) ;
133
+ if ( nsplits == null ) return array_ops . shape ( this . _row_splits , out_type : this . row_splits . dtype ) [ 0 ] - 1 ;
134
+ else return constant_op . constant ( nsplits . value - 1 , dtype : this . row_splits . dtype ) ;
135
+ }
136
+
137
+ public Tensor row_lengths ( )
138
+ {
139
+
140
+ if ( this . _row_splits != null )
141
+ {
142
+ int nrows_plus_one = tensor_shape . dimension_value ( this . _row_splits . shape [ 0 ] ) ;
143
+ return tf . constant ( nrows_plus_one - 1 ) ;
144
+
145
+ }
146
+ if ( this . _row_lengths != null )
147
+ {
148
+ var nrows = tensor_shape . dimension_value ( this . _row_lengths . shape [ 0 ] ) ;
149
+ return tf . constant ( nrows ) ;
150
+ }
151
+ if ( this . _nrows != null )
152
+ {
153
+ return tensor_util . constant_value ( this . _nrows ) ;
154
+ }
155
+ return tf . constant ( - 1 ) ;
156
+ }
102
157
}
103
158
}
0 commit comments