Skip to content

Commit c2bc6cf

Browse files
authored
Suport keras.op.view() to view the same data bitwise at a new dtype (#21763)
* Support view() in keras to view an array at different dtype. * Update comment with correct view example. * resolve gemini code assist comments * Format lint. * fix lint error and add openvino backend code, which we don't implement at this time. * Add default None to key args of view() to passun-implemented openvino tests. * remove print * skip openvino * Add jax x64 test with enabling x64. * resolve comments * restructure test * format * fix format * skip openvino tests * fix openvino sig * resolve comments * reformat * simplify test * resolve comments convert_to_tensor for all backends * resolve comment simplify return line
1 parent 10b51ce commit c2bc6cf

File tree

12 files changed

+209
-0
lines changed

12 files changed

+209
-0
lines changed

keras/api/_tf_keras/keras/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@
296296
from keras.src.ops.numpy import var as var
297297
from keras.src.ops.numpy import vdot as vdot
298298
from keras.src.ops.numpy import vectorize as vectorize
299+
from keras.src.ops.numpy import view as view
299300
from keras.src.ops.numpy import vstack as vstack
300301
from keras.src.ops.numpy import where as where
301302
from keras.src.ops.numpy import zeros as zeros

keras/api/_tf_keras/keras/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@
182182
from keras.src.ops.numpy import var as var
183183
from keras.src.ops.numpy import vdot as vdot
184184
from keras.src.ops.numpy import vectorize as vectorize
185+
from keras.src.ops.numpy import view as view
185186
from keras.src.ops.numpy import vstack as vstack
186187
from keras.src.ops.numpy import where as where
187188
from keras.src.ops.numpy import zeros as zeros

keras/api/ops/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,7 @@
296296
from keras.src.ops.numpy import var as var
297297
from keras.src.ops.numpy import vdot as vdot
298298
from keras.src.ops.numpy import vectorize as vectorize
299+
from keras.src.ops.numpy import view as view
299300
from keras.src.ops.numpy import vstack as vstack
300301
from keras.src.ops.numpy import where as where
301302
from keras.src.ops.numpy import zeros as zeros

keras/api/ops/numpy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@
182182
from keras.src.ops.numpy import var as var
183183
from keras.src.ops.numpy import vdot as vdot
184184
from keras.src.ops.numpy import vectorize as vectorize
185+
from keras.src.ops.numpy import view as view
185186
from keras.src.ops.numpy import vstack as vstack
186187
from keras.src.ops.numpy import where as where
187188
from keras.src.ops.numpy import zeros as zeros

keras/src/backend/jax/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -446,6 +446,11 @@ def array(x, dtype=None):
446446
return jnp.array(x, dtype=dtype)
447447

448448

449+
def view(x, dtype=None):
450+
x = convert_to_tensor(x)
451+
return x.view(dtype=dtype)
452+
453+
449454
def average(x, axis=None, weights=None):
450455
x = convert_to_tensor(x)
451456
dtypes_to_resolve = [x.dtype, float]

keras/src/backend/numpy/numpy.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,11 @@ def array(x, dtype=None):
294294
return convert_to_tensor(x, dtype=dtype)
295295

296296

297+
def view(x, dtype=None):
298+
x = convert_to_tensor(x)
299+
return x.view(dtype=dtype)
300+
301+
297302
def average(x, axis=None, weights=None):
298303
axis = standardize_axis_for_numpy(axis)
299304
x = convert_to_tensor(x)

keras/src/backend/openvino/excluded_concrete_tests.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,7 @@ NumpyDtypeTest::test_trunc
5151
NumpyDtypeTest::test_unravel
5252
NumpyDtypeTest::test_var
5353
NumpyDtypeTest::test_vdot
54+
NumpyDtypeTest::test_view
5455
NumpyDtypeTest::test_vstack
5556
HistogramTest
5657
NumpyOneInputOpsCorrectnessTest::test_angle
@@ -102,6 +103,7 @@ NumpyOneInputOpsCorrectnessTest::test_unravel_index
102103
NumpyOneInputOpsCorrectnessTest::test_var
103104
NumpyOneInputOpsCorrectnessTest::test_vectorize
104105
NumpyOneInputOpsCorrectnessTest::test_vstack
106+
NumpyOneInputOpsCorrectnessTest::test_view
105107
NumpyTwoInputOpsCorrectnessTest::test_bitwise_and
106108
NumpyTwoInputOpsCorrectnessTest::test_bitwise_left_shift
107109
NumpyTwoInputOpsCorrectnessTest::test_bitwise_or
@@ -131,10 +133,12 @@ NumpyOneInputOpsDynamicShapeTest::test_hanning
131133
NumpyOneInputOpsDynamicShapeTest::test_isposinf
132134
NumpyOneInputOpsDynamicShapeTest::test_isreal
133135
NumpyOneInputOpsDynamicShapeTest::test_kaiser
136+
NumpyOneInputOpsDynamicShapeTest::test_view
134137
NumpyOneInputOpsStaticShapeTest::test_angle
135138
NumpyOneInputOpsStaticShapeTest::test_cbrt
136139
NumpyOneInputOpsStaticShapeTest::test_isposinf
137140
NumpyOneInputOpsStaticShapeTest::test_isreal
141+
NumpyOneInputOpsStaticShapeTest::test_view
138142
NumpyTwoInputOpsDynamicShapeTest::test_gcd
139143
NumpyTwoInputOpsDynamicShapeTest::test_heaviside
140144
NumpyTwoInputOpsDynamicShapeTest::test_hypot

keras/src/backend/openvino/numpy.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -508,6 +508,10 @@ def array(x, dtype=None):
508508
return np.array(x)
509509

510510

511+
def view(x, dtype=None):
512+
raise NotImplementedError("`view` is not supported with openvino backend")
513+
514+
511515
def average(x, axis=None, weights=None):
512516
x = get_ov_output(x)
513517
if weights is not None:

keras/src/backend/tensorflow/numpy.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -998,6 +998,49 @@ def array(x, dtype=None):
998998
return convert_to_tensor(x, dtype=dtype)
999999

10001000

1001+
def view(x, dtype=None):
1002+
from keras.src import backend
1003+
1004+
x = convert_to_tensor(x)
1005+
old_dtype = tf.as_dtype(backend.standardize_dtype(x.dtype))
1006+
new_dtype = tf.as_dtype(
1007+
backend.standardize_dtype(dtype if dtype else x.dtype)
1008+
)
1009+
1010+
old_itemsize = old_dtype.size
1011+
new_itemsize = new_dtype.size
1012+
1013+
if list(x.shape)[-1] * old_itemsize % new_itemsize != 0:
1014+
raise ValueError(
1015+
f"Cannot view array of shape {x.shape} and dtype {old_dtype} "
1016+
f"as dtype {new_dtype} because the total number of bytes "
1017+
f"is not divisible by the new itemsize."
1018+
)
1019+
1020+
if old_itemsize == new_itemsize:
1021+
return tf.bitcast(x, type=new_dtype)
1022+
elif old_itemsize > new_itemsize:
1023+
ratio = old_itemsize // new_itemsize
1024+
new_shape = list(shape_op(x))
1025+
new_shape[-1] *= ratio
1026+
flat_tensor = tf.reshape(x, [-1])
1027+
cast_tensor = tf.bitcast(flat_tensor, type=new_dtype)
1028+
return tf.reshape(cast_tensor, new_shape)
1029+
else:
1030+
old_shape = list(shape_op(x))
1031+
last_dim_size = old_shape[-1]
1032+
ratio = new_itemsize // old_itemsize
1033+
if isinstance(last_dim_size, int) and last_dim_size % ratio != 0:
1034+
raise ValueError(
1035+
f"Cannot view dtype. Last dimension size ({last_dim_size}) "
1036+
f"must be divisible by the ratio of new/old item sizes "
1037+
f"({ratio})."
1038+
)
1039+
intermediate_shape = old_shape[:-1] + [last_dim_size // ratio, ratio]
1040+
reshaped_tensor = tf.reshape(x, intermediate_shape)
1041+
return tf.bitcast(reshaped_tensor, new_dtype)
1042+
1043+
10011044
def average(x, axis=None, weights=None):
10021045
x = convert_to_tensor(x)
10031046

keras/src/backend/torch/numpy.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -411,6 +411,12 @@ def array(x, dtype=None):
411411
return convert_to_tensor(x, dtype=dtype)
412412

413413

414+
def view(x, dtype=None):
415+
dtype = to_torch_dtype(dtype)
416+
x = convert_to_tensor(x)
417+
return x.view(dtype=dtype)
418+
419+
414420
def average(x, axis=None, weights=None):
415421
x = convert_to_tensor(x)
416422
dtypes_to_resolve = [x.dtype, float]

0 commit comments

Comments
 (0)