@@ -1548,6 +1548,28 @@ def test_reverse_sequence_time_major(self):
15481548
15491549 @check_opset_min_version (8 , "where" )
15501550 def test_where (self ):
1551+ x_val = np .array ([1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ], dtype = np .float32 )
1552+ true_result = np .array ([111 , 222 , 333 , 444 , 555 , 666 , 777 , 888 , 999 , 1000 ],
1553+ dtype = np .float32 )
1554+ false_result = np .array ([- 111 , - 222 , - 333 , - 444 , - 555 , - 666 , - 777 , - 888 , - 999 , - 1000 ],
1555+ dtype = np .float32 )
1556+ x = tf .placeholder (tf .float32 , x_val .shape , name = _TFINPUT )
1557+ picks = tf .where (x > - 1 , true_result , false_result )
1558+ _ = tf .identity (picks , name = _TFOUTPUT )
1559+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1560+
1561+ tf .reset_default_graph ()
1562+ x_val = np .array (1 , dtype = np .float32 )
1563+ true_result = np .array (100 , dtype = np .float32 )
1564+ false_result = np .array (- 111 , dtype = np .float32 )
1565+ x = tf .placeholder (tf .float32 , x_val .shape , name = _TFINPUT )
1566+ picks = tf .where (x > - 1 , true_result , false_result )
1567+ _ = tf .identity (picks , name = _TFOUTPUT )
1568+ self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
1569+
1570+ @check_opset_min_version (8 , "where" )
1571+ @check_target ("rs6" , "onnxruntime Where type limitation" )
1572+ def test_where_int32 (self ):
15511573 x_val = np .array ([1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ], dtype = np .int32 )
15521574 true_result = np .array ([111 , 222 , 333 , 444 , 555 , 666 , 777 , 888 , 999 , 1000 ],
15531575 dtype = np .int32 )
@@ -1560,59 +1582,59 @@ def test_where(self):
15601582
15611583 @check_opset_min_version (8 , "where" )
15621584 def test_where_with_two_rank_input (self ):
1563- x_val = np .array ([1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ], dtype = np .int32 )
1585+ x_val = np .array ([1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ], dtype = np .float32 )
15641586 true_result = np .array ([[111 , 111 ], [222 , 222 ], [333 , 333 ], [444 , 444 ], [555 , 555 ],
15651587 [666 , 666 ], [777 , 777 ], [888 , 888 ], [999 , 999 ], [1000 , 1000 ]],
1566- dtype = np .int32 )
1588+ dtype = np .float32 )
15671589 false_result = np .array ([[- 111 , - 111 ], [- 222 , - 222 ], [- 333 , - 333 ], [- 444 , - 444 ],
15681590 [- 555 , - 555 ], [- 666 , - 666 ], [- 777 , - 777 ], [- 888 , - 888 ],
15691591 [- 999 , - 999 ], [- 1000 , - 1000 ]],
1570- dtype = np .int32 )
1571- x = tf .placeholder (tf .int32 , [None ], name = _TFINPUT )
1592+ dtype = np .float32 )
1593+ x = tf .placeholder (tf .float32 , [None ], name = _TFINPUT )
15721594 picks = tf .where (tf .greater_equal (x , 0 ), true_result , false_result )
15731595 _ = tf .identity (picks , name = _TFOUTPUT )
15741596
15751597 self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
15761598
15771599 @check_opset_min_version (8 , "where" )
15781600 def test_where_with_two_rank_condition (self ):
1579- x_val = np .array ([[1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ]], dtype = np .int32 )
1601+ x_val = np .array ([[1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ]], dtype = np .float32 )
15801602 true_result = np .array ([[111 , 222 , 333 , 444 , 555 , 666 , 777 , 888 , 999 , 1000 ]],
1581- dtype = np .int32 )
1603+ dtype = np .float32 )
15821604 false_result = np .array ([[- 111 , - 222 , - 333 , - 444 , - 555 , - 666 , - 777 , - 888 , - 999 , - 1000 ]],
1583- dtype = np .int32 )
1584- x = tf .placeholder (tf .int32 , [1 , 10 ], name = _TFINPUT )
1605+ dtype = np .float32 )
1606+ x = tf .placeholder (tf .float32 , [1 , 10 ], name = _TFINPUT )
15851607 picks = tf .where (tf .greater_equal (x , 0 ), true_result , false_result )
15861608 _ = tf .identity (picks , name = _TFOUTPUT )
15871609
15881610 self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
15891611
15901612 @check_opset_min_version (8 , "where" )
15911613 def test_where_with_three_rank_condition (self ):
1592- x_val = np .array ([[[1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ]]], dtype = np .int32 )
1614+ x_val = np .array ([[[1 , 2 , - 3 , 4 , - 5 , - 6 , - 7 , 8 , 9 , 0 ]]], dtype = np .float32 )
15931615 true_result = np .array ([[[111 , 222 , 333 , 444 , 555 , 666 , 777 , 888 , 999 , 1000 ]]],
1594- dtype = np .int32 )
1616+ dtype = np .float32 )
15951617 false_result = np .array ([[[- 111 , - 222 , - 333 , - 444 , - 555 , - 666 , - 777 , - 888 , - 999 , - 1000 ]]],
1596- dtype = np .int32 )
1597- x = tf .placeholder (tf .int32 , [1 , 1 , 10 ], name = _TFINPUT )
1618+ dtype = np .float32 )
1619+ x = tf .placeholder (tf .float32 , [1 , 1 , 10 ], name = _TFINPUT )
15981620 picks = tf .where (tf .greater_equal (x , 0 ), true_result , false_result )
15991621 _ = tf .identity (picks , name = _TFOUTPUT )
16001622
16011623 self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
16021624
16031625 @check_opset_min_version (8 , "where" )
16041626 def test_where_scalar (self ):
1605- x_val = np .array (6 , dtype = np .int32 )
1627+ x_val = np .array (6 , dtype = np .float32 )
16061628 true_result = np .array ([111 , 222 , 333 , 444 , 555 , 666 , 777 , 888 , 999 , 1000 ],
1607- dtype = np .int32 )
1629+ dtype = np .float32 )
16081630 false_result = np .array ([- 111 , - 222 , - 333 , - 444 , - 555 , - 666 , - 777 , - 888 , - 999 , - 1000 ],
1609- dtype = np .int32 )
1610- x = tf .placeholder (tf .int32 , [], name = _TFINPUT )
1631+ dtype = np .float32 )
1632+ x = tf .placeholder (tf .float32 , [], name = _TFINPUT )
16111633 picks = tf .where (tf .greater_equal (x , 0 ), true_result , false_result )
16121634 _ = tf .identity (picks , name = _TFOUTPUT )
16131635 self ._run_test_case ([_OUTPUT ], {_INPUT : x_val })
16141636
1615- @check_opset_min_version (9 , "where " )
1637+ @check_opset_min_version (9 , "NonZero " )
16161638 @check_target ("rs6" , "onnxruntime Transpose type limitation" )
16171639 def test_where_with_cond_only (self ):
16181640 for np_type , tf_type in [(np .int32 , tf .int32 ), (np .float32 , tf .float32 )]:
0 commit comments