Skip to content

Commit 01bd2dc

Browse files
authored
fix documentation (#102)
1 parent ac4acc6 commit 01bd2dc

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

onnx_array_api/reference/ops/op_cast_like.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717

1818
def _cast_like(x, y, saturate):
1919
if bfloat16 is None:
20-
return (cast_to(x, y.dtype, saturate),)
20+
to = np_dtype_to_tensor_dtype(y.dtype)
21+
return (cast_to(x, to, saturate),)
2122
if y.dtype == bfloat16 and y.dtype.descr[0][0] == "bfloat16":
2223
# np.uint16 == np.uint16 is True as well as np.uint16 == bfloat16
2324
to = TensorProto.BFLOAT16

0 commit comments

Comments
 (0)