@@ -423,7 +423,10 @@ class GatherV2:
     @classmethod
     def version_1(cls, ctx, node, **kwargs):
         # for GatherV2 axis come as input
+        err_msg = "Opset 12 required for batch_dims attribute of GatherV2"
+        utils.make_sure(node.get_attr_value("batch_dims", 0) == 0, err_msg)
         node.type = "Gather"
+        utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
         axis = node.inputs[2].get_tensor_value()
         ctx.remove_input(node, node.input[2], 2)
         node.set_attr("axis", axis)
@@ -433,6 +436,42 @@ def version_11(cls, ctx, node, **kwargs):
         # no change
         cls.version_1(ctx, node, **kwargs)
 
+    @classmethod
+    def version_12(cls, ctx, node, **kwargs):
+        batch_dims = node.get_attr_value("batch_dims", 0)
+        if batch_dims == 0:
+            cls.version_1(ctx, node, **kwargs)
+            return
+        # If batch_dims is not zero, use GatherND to simulate Gather with batch dims.
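+        # ONNX GatherND supports batch_dims from opset 12 onward. The expected result
+        # shape is params.shape[:axis] + indices.shape[batch_dims:] + params.shape[axis+1:],
+        # e.g. params [2, 3, 4, 5], indices [2, 6], batch_dims=1, axis=2 -> [2, 3, 6, 5].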
+        data_inp, indices_inp, axis_inp = node.input
+        utils.make_sure(node.inputs[2].is_const(), "Axis of GatherV2 node must be constant")
+        axis = node.inputs[2].get_tensor_value()
+        ctx.remove_input(node, axis_inp, 2)
+        if ctx.get_dtype(indices_inp) != TensorProto.INT64:
+            indices_inp = ctx.make_node("Cast", [indices_inp], attr={'to': TensorProto.INT64}).output[0]
+        unperm = None
+        # GatherND doesn't take an axis so we have to transpose stuff around
+        if axis != batch_dims:
+            data_rank = ctx.get_rank(data_inp)
+            indices_rank = ctx.get_rank(indices_inp)
+            err_msg = "Cannot convert GatherV2 with batch dims since inputs have unknown ranks."
+            utils.make_sure(data_rank is not None and indices_rank is not None, err_msg)
+            result_rank = data_rank + indices_rank - 1 - batch_dims
+            shift_amt = axis - batch_dims
+            perm = list(range(data_rank))
+            perm = perm[:batch_dims] + perm[axis:axis + 1] + perm[batch_dims:axis] + perm[axis + 1:]
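+            # e.g. data_rank=4, batch_dims=1, axis=2 -> perm = [0, 2, 1, 3]:
+            # the gathered axis moves to just after the batch dims.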
+            data_inp = ctx.make_node("Transpose", [data_inp], attr={'perm': perm}).output[0]
+            ctx.replace_input(node, node.input[0], data_inp, 0)
+            unperm = list(range(result_rank))
+            j = indices_rank + shift_amt
+            unperm = unperm[:batch_dims] + unperm[indices_rank:j] + unperm[batch_dims:indices_rank] + unperm[j:]
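+            # unperm restores the TF layout params[:axis] + indices[batch_dims:] + params[axis+1:],
+            # e.g. data_rank=4, indices_rank=2, batch_dims=1, axis=2 -> unperm = [0, 2, 1, 3].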
+        node.type = "GatherND"
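+        # Add a size-1 innermost dim to indices so each index tuple picks a single
+        # element along the gathered (now first non-batch) axis.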
+        unsqueeze_node = GraphBuilder(ctx).make_unsqueeze({'data': indices_inp, 'axes': [-1]})
+        ctx.replace_input(node, node.input[1], unsqueeze_node, 1)
+        if unperm is not None:
+            ctx.update_node_shape_dtype(node, override=True)
+            ctx.insert_new_node_on_output("Transpose", node.output[0], perm=unperm)
+
 
 def _make_gathernd_inner_loop(ctx, params, index, dtype):
     """create the inner loop for GatherNd."""
@@ -2077,43 +2116,77 @@ def ragged_lengths_to_sparse_indices(ctx, ragged_lens):
     return num_rows, num_cols, row_indices, col_indices
 
 
+def ragged_nested_splits_to_sparse_indices(ctx, nested_splits, op_name_scope):
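+    """Convert nested ragged row splits into sparse indices and a dense shape.
+    Shared by the RaggedTensorToSparse and RaggedTensorToTensor handlers."""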
+    sparse_indices = None
+    dense_shape_dims = []
+    for split in nested_splits:
+        if ctx.get_dtype(split) != TensorProto.INT64:
+            split = ctx.make_node("Cast", [split], attr={'to': TensorProto.INT64}).output[0]
+        max_int64 = int(utils.get_max_value(np.int64))
+        slice1 = GraphBuilder(ctx).make_slice(
+            {"data": split, "ends": [max_int64], "starts": [1], "axes": [0]})
+        slice2 = GraphBuilder(ctx).make_slice(
+            {"data": split, "ends": [-1], "starts": [0], "axes": [0]})
+        ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
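+        # Row lengths are adjacent differences of the splits,
+        # e.g. row_splits [0, 2, 2, 5] -> ragged_lens [2, 0, 3].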
+        num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
+        if not dense_shape_dims:
+            dense_shape_dims.append(num_rows)
+        dense_shape_dims.append(num_cols)
+        if sparse_indices is None:
+            row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
+        else:
+            row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
+        col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
+        sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
+                                       op_name_scope=op_name_scope).output[0]
+    dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=op_name_scope).output[0]
+    return sparse_indices, dense_shape
+
+
 @tf_op("RaggedTensorToSparse")
 class RaggedTensorToSparse:
     @classmethod
     def version_11(cls, ctx, node, **kwargs):
         # https://www.tensorflow.org/guide/ragged_tensor#multiple_ragged_dimensions
         dense_values = node.input[-1]
         nested_splits = node.input[:-1]
-        sparse_indices = None
-        dense_shape_dims = []
-        for split in nested_splits:
-            if ctx.get_dtype(split) != TensorProto.INT64:
-                split = ctx.make_node("Cast", [split], attr={'to': TensorProto.INT64}).output[0]
-            max_int64 = int(utils.get_max_value(np.int64))
-            slice1 = GraphBuilder(ctx).make_slice(
-                {"data": split, "ends": [max_int64], "starts": [1], "axes": [0]})
-            slice2 = GraphBuilder(ctx).make_slice(
-                {"data": split, "ends": [-1], "starts": [0], "axes": [0]})
-            ragged_lens = ctx.make_node("Sub", [slice1, slice2]).output[0]
-            num_rows, num_cols, row_indices, col_indices = ragged_lengths_to_sparse_indices(ctx, ragged_lens)
-            if not dense_shape_dims:
-                dense_shape_dims.append(num_rows)
-            dense_shape_dims.append(num_cols)
-            if sparse_indices is None:
-                row_indices = GraphBuilder(ctx).make_unsqueeze({"data": row_indices, "axes": [1]})
-            else:
-                row_indices = ctx.make_node("Gather", [sparse_indices, row_indices]).output[0]
-            col_indices = GraphBuilder(ctx).make_unsqueeze({"data": col_indices, "axes": [1]})
-            sparse_indices = ctx.make_node("Concat", [row_indices, col_indices], attr={'axis': 1},
-                                           op_name_scope=node.name).output[0]
-        dense_shape = ctx.make_node("Concat", dense_shape_dims, attr={'axis': 0}, op_name_scope=node.name).output[0]
-
+        sparse_indices, dense_shape = ragged_nested_splits_to_sparse_indices(ctx, nested_splits, node.name)
         ctx.replace_all_inputs(node.output[0], sparse_indices)
         ctx.replace_all_inputs(node.output[1], dense_values)
         ctx.replace_all_inputs(node.output[2], dense_shape)
         ctx.remove_node(node.name)
 
 
+@tf_op("RaggedTensorToTensor")
+class RaggedTensorToTensor:
+    @classmethod
+    def version_11(cls, ctx, node, **kwargs):
+        shape, values, default_value, *row_partition_tensors = node.input
+        partition_types = node.get_attr_value("row_partition_types")
+        error_msg = "Only ROW_SPLITS partition type is supported for RaggedTensorToTensor. types: %r"
+        utils.make_sure(all(t == b'ROW_SPLITS' for t in partition_types), error_msg, partition_types)
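+        # TF also defines ROW_LENGTHS/VALUE_ROWIDS/FIRST_DIM_SIZE partition types;
+        # those are not handled here.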
2168
+ nested_splits = row_partition_tensors
2169
+ sparse_indices , dense_shape = ragged_nested_splits_to_sparse_indices (ctx , nested_splits , node .name )
2170
+ # A shape of rank 0 means the natural shape should be used.
2171
+ if ctx .get_rank (shape ) != 0 :
2172
+ if ctx .get_dtype (shape ) != TensorProto .INT64 :
2173
+ shape = ctx .make_node ("Cast" , [shape ], attr = {'to' : TensorProto .INT64 }).output [0 ]
2174
+ const_zero_int64 = ctx .make_const (utils .make_name ("const_zero" ), np .array (0 , dtype = np .int64 )).output [0 ]
2175
+ unspec_dims = ctx .make_node ("Less" , [shape , const_zero_int64 ]).output [0 ]
+            out_shape = ctx.make_node("Where", [unspec_dims, dense_shape, shape]).output[0]
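+            # Dims requested as -1 are unspecified; Where fills them in from the
+            # natural dense shape implied by the splits.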
+            out_shape_unsq = GraphBuilder(ctx).make_unsqueeze({'data': out_shape, 'axes': [0]})
+            amt_idx_in_bounds = ctx.make_node("Sub", [out_shape_unsq, sparse_indices]).output[0]
+            amt_in_bounds_flat = ctx.make_node("ReduceMin", [amt_idx_in_bounds], attr={'axes': [1], 'keepdims': False})
+            idx_in_bounds = ctx.make_node("Greater", [amt_in_bounds_flat.output[0], const_zero_int64]).output[0]
+            sparse_indices = ctx.make_node("Compress", [sparse_indices, idx_in_bounds], attr={'axis': 0}).output[0]
+            values = ctx.make_node("Compress", [values, idx_in_bounds], attr={'axis': 0}).output[0]
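+            # An index tuple is in bounds iff out_shape - idx > 0 in every dim.
+            # Out-of-bounds indices (and their values) are compressed away so
+            # ScatterND never writes outside the requested shape.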
+        else:
+            out_shape = dense_shape
+        expand_node = ctx.make_node("Expand", [default_value, out_shape])
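+        # Expand broadcasts default_value to the full output shape; ScatterND then
+        # overwrites the ragged values at their sparse indices.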
+        node.type = "ScatterND"
+        ctx.replace_inputs(node, [expand_node.output[0], sparse_indices, values])
+
+
 @tf_op("RaggedRange")
 class RaggedRange:
     @classmethod