Skip to content

Commit 64c959a

Browse files
committed
FEAT: Adding clamp function and relevant tests
1 parent 31ffd98 commit 64c959a

File tree

2 files changed

+43
-0
lines changed

2 files changed

+43
-0
lines changed

arrayfire/arith.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -126,6 +126,44 @@ def maxof(lhs, rhs):
126126
"""
127127
return _arith_binary_func(lhs, rhs, backend.get().af_maxof)
128128

129+
def clamp(val, low, high):
130+
"""
131+
Clamp the input value between low and high
132+
133+
134+
Parameters
135+
----------
136+
val : af.Array
137+
Multi dimensional arrayfire array to be clamped.
138+
139+
low : af.Array or scalar
140+
Multi dimensional arrayfire array or a scalar number denoting the lower value(s).
141+
142+
high : af.Array or scalar
143+
Multi dimensional arrayfire array or a scalar number denoting the higher value(s).
144+
"""
145+
out = Array()
146+
147+
is_low_array = isinstance(low, Array)
148+
is_high_array = isinstance(high, Array)
149+
150+
vdims = dim4_to_tuple(val.dims())
151+
vty = val.type()
152+
153+
if not is_low_array:
154+
low_arr = constant_array(low, vdims[0], vdims[1], vdims[2], vdims[3], vty)
155+
else:
156+
low_arr = low.arr
157+
158+
if not is_high_array:
159+
high_arr = constant_array(high, vdims[0], vdims[1], vdims[2], vdims[3], vty)
160+
else:
161+
high_arr = high.arr
162+
163+
safe_call(backend.get().af_clamp(ct.pointer(out.arr), val.arr, low_arr, high_arr, _bcast_var.get()))
164+
165+
return out
166+
129167
def rem(lhs, rhs):
130168
"""
131169
Find the remainder.

arrayfire/tests/simple/arith.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,11 @@ def simple_arith(verbose = False):
134134
display_func(af.cast(a, af.Dtype.c32))
135135
display_func(af.maxof(a,b))
136136
display_func(af.minof(a,b))
137+
138+
display_func(af.clamp(a, 0, 1))
139+
display_func(af.clamp(a, 0, b))
140+
display_func(af.clamp(a, b, 1))
141+
137142
display_func(af.rem(a,b))
138143

139144
a = af.randu(3,3) - 0.5

0 commit comments

Comments
 (0)