Skip to content

Commit d192a37

Browse files
committed
BF: Set strided_scalar as not writeable
Numpy 1.10 cannot broadcast a strided array that is set as writeable
1 parent f70961a commit d192a37

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

nibabel/fileslice.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -767,9 +767,15 @@ def strided_scalar(shape, scalar=0.):
767767
strided_arr : array
768768
Array of shape `shape` for which all values == `scalar`, built by
769769
setting all strides of `strided_arr` to 0, so the scalar is broadcast
770-
out to the full array `shape`.
770+
out to the full array `shape`. `strided_arr` is flagged as not
771+
`writeable`.
772+
773+
The resulting array is set read-only to avoid a numpy error when
774+
broadcasting - see https://github.com/numpy/numpy/issues/6491
771775
"""
772776
shape = tuple(shape)
773777
scalar = np.array(scalar)
774778
strides = [0] * len(shape)
775-
return np.lib.stride_tricks.as_strided(scalar, shape, strides)
779+
strided_scalar = np.lib.stride_tricks.as_strided(scalar, shape, strides)
780+
strided_scalar.flags.writeable = False
781+
return strided_scalar

nibabel/tests/test_fileslice.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -639,8 +639,14 @@ def test_strided_scalar():
639639
assert_equal(observed.shape, shape)
640640
assert_equal(observed.dtype, expected.dtype)
641641
assert_array_equal(observed.strides, 0)
642-
observed[..., 0] = 99
643-
assert_array_equal(observed, expected * 0 + 99)
642+
# Strided scalars are set as not writeable
643+
# This addresses a numpy 1.10 breakage of broadcasting a strided
644+
# array without resizing (see GitHub PR #358)
645+
assert_false(observed.flags.writeable)
646+
def setval(x):
647+
x[..., 0] = 99
648+
# RuntimeError for numpy < 1.10
649+
assert_raises((RuntimeError, ValueError), setval, observed)
644650
# Default scalar value is 0
645651
assert_array_equal(strided_scalar((2, 3, 4)), np.zeros((2, 3, 4)))
646652

0 commit comments

Comments
 (0)