@@ -2485,26 +2485,46 @@ def test_batch_to_spacend(self):
24852485 self ._run_test_case ([_OUTPUT ], {_INPUT : input_val })
24862486
24872487 @check_opset_min_version (11 , "BatchToSpaceND" )
2488- def test_batch_to_spacend_non_const (self ):
2489- input_x_val = np .random .random_sample ([40 , 3 , 5 , 100 ]).astype (np .float32 ) # NHWC
2490- block_shape_val = np .array ([2 , 2 ]).astype (np .int64 )
2491- crops_val = np .array ([[1 , 0 ], [2 , 1 ]]).astype (np .int64 )
2492- input_x = tf .placeholder (dtype = tf .float32 , shape = input_x_val .shape , name = _TFINPUT )
2493- block_shape = tf .placeholder (dtype = tf .int64 , shape = block_shape_val .shape , name = _TFINPUT1 )
2494- crops = tf .placeholder (dtype = tf .int64 , shape = crops_val .shape , name = _TFINPUT2 )
2495- _ = tf .batch_to_space_nd (input_x , block_shape , crops , name = _TFOUTPUT )
2496- self ._run_test_case ([_OUTPUT ], {_INPUT : input_x_val , _INPUT1 : block_shape_val , _INPUT2 : crops_val })
2488+ def test_batch_to_spacend_non_const_7d (self ):
2489+ x_type , y_type , z_type = np .int64 , np .int64 , np .int64
2490+ # test 3D upto 7D input tensors
2491+ for x_shape in [[12 , 4 , 4 ], [12 , 4 , 8 , 3 ], [12 , 4 , 8 , 3 , 2 ], [12 , 4 , 8 , 3 , 2 , 3 ], [12 , 4 , 8 , 3 , 2 , 1 , 3 ]]:
2492+ # test 1D upto 2D block shapes
2493+ for block_shape in [[2 , 3 ], [2 ]]:
2494+ tf .reset_default_graph ()
2495+ # crop 1 layer at end of each dim
2496+ crops = [[0 , 1 ] for dim in block_shape ]
2497+ y_val = np .array (block_shape ).astype (y_type )
2498+ x_val = np .array ([x + 1 for x in range (0 , np .prod (x_shape ))], dtype = x_type ).reshape (x_shape )
2499+ z_val = np .array (crops ).astype (z_type )
2500+ # x and z can be dynamic.
2501+ # y = block_shape cannot be dynamic without change to Transpose op spec
2502+ x = tf .placeholder (dtype = x_type , shape = x_val .shape , name = _TFINPUT )
2503+ y = tf .constant (dtype = y_type , value = y_val , shape = y_val .shape , name = _TFINPUT1 )
2504+ z = tf .placeholder (dtype = z_type , shape = z_val .shape , name = _TFINPUT2 )
2505+ _ = tf .batch_to_space_nd (x , y , z , name = _TFOUTPUT )
2506+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val , _INPUT2 : z_val })
24972507
24982508 @check_opset_min_version (11 , "SpaceToBatchND" )
2499- def test_space_to_batchnd_non_const (self ):
2500- input_x_val = np .random .random_sample ([40 , 5 , 7 , 66 ]).astype (np .float32 ) # NHWC
2501- block_size_val = np .array ([2 , 2 ]).astype (np .int64 )
2502- pad_val = np .array ([[0 , 1 ], [2 , 1 ]]).astype (np .int64 )
2503- input_x = tf .placeholder (dtype = tf .float32 , shape = input_x_val .shape , name = _TFINPUT )
2504- block_size = tf .placeholder (dtype = tf .int64 , shape = block_size_val .shape , name = _TFINPUT1 )
2505- pad = tf .placeholder (dtype = tf .int64 , shape = pad_val .shape , name = _TFINPUT2 )
2506- _ = tf .space_to_batch_nd (input_x , block_size , pad , name = _TFOUTPUT )
2507- self ._run_test_case ([_OUTPUT ], {_INPUT : input_x_val , _INPUT1 : block_size_val , _INPUT2 : pad_val })
2509+ def test_space_to_batchnd_non_const_7d (self ):
2510+ x_type , y_type , z_type = np .int64 , np .int64 , np .int64
2511+ # test 3D upto 7D input tensors
2512+ for x_shape in [[2 , 4 , 4 ], [1 , 4 , 8 , 3 ], [1 , 4 , 8 , 3 , 2 ], [1 , 4 , 8 , 3 , 2 , 3 ], [1 , 4 , 8 , 3 , 2 , 1 , 3 ]]:
2513+ # test 1D upto 2D block shapes
2514+ for block_shape in [[2 ], [2 , 2 ]]:
2515+ tf .reset_default_graph ()
2516+ # pad 1 layer at begin and end of each dim
2517+ pads = [[1 , 1 ] for dim in block_shape ]
2518+ y_val = np .array (block_shape ).astype (y_type )
2519+ x_val = np .array ([x + 1 for x in range (0 , np .prod (x_shape ))], dtype = x_type ).reshape (x_shape )
2520+ z_val = np .array (pads ).astype (z_type )
2521+ # x and z can be dynamic.
2522+ # y = block_shape cannot be dynamic without change to Transpose op spec
2523+ x = tf .placeholder (dtype = x_type , shape = x_val .shape , name = _TFINPUT )
2524+ y = tf .constant (dtype = y_type , value = y_val , shape = y_val .shape , name = _TFINPUT1 )
2525+ z = tf .placeholder (dtype = z_type , shape = z_val .shape , name = _TFINPUT2 )
2526+ _ = tf .space_to_batch_nd (x , y , z , name = _TFOUTPUT )
2527+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val , _INPUT2 : z_val })
25082528
25092529 @check_opset_min_version (11 , "CropAndResize" )
25102530 def test_crop_and_resize_linear (self ):
0 commit comments