Skip to content

Commit 32bad36

Browse files
authored
Merge pull request #596 from mrava87/patch_numpyv2
Feature: migrate to numpy v2.0.0
2 parents a9c7f0f + 6510207 commit 32bad36

13 files changed

+113
-25
lines changed

.readthedocs.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,6 @@ sphinx:
1818
# Declare the Python requirements required to build your docs
1919
python:
2020
install:
21-
- requirements: requirements-dev.txt
21+
- requirements: requirements-doc.txt
2222
- method: pip
2323
path: .

docs/source/installation.rst

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -321,6 +321,11 @@ In alphabetic order:
321321

322322
dtcwt
323323
-----
324+
325+
.. warning::
326+
327+
``dtcwt`` is not yet supported with Numpy 2.
328+
324329
`dtcwt <https://dtcwt.readthedocs.io/en/0.12.0/>`_ is a library used to implement the DT-CWT operators.
325330

326331
Install it via ``pip`` with:
@@ -330,6 +335,7 @@ Install it via ``pip`` with:
330335
>> pip install dtcwt
331336
332337
338+
333339
Devito
334340
------
335341
`Devito <https://github.com/devitocodes/devito>`_ is a library used to solve PDEs via
@@ -468,6 +474,11 @@ or with ``pip`` via
468474
469475
SPGL1
470476
-----
477+
478+
.. warning::
479+
480+
``SPGL1`` is not yet supported with Numpy 2.
481+
471482
`SPGL1 <https://spgl1.readthedocs.io/en/latest/>`_ is used to solve sparsity-promoting
472483
basis pursuit, basis pursuit denoise, and Lasso problems
473484
in :py:func:`pylops.optimization.sparsity.SPGL1` solver.

environment-dev-arm.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ channels:
77
dependencies:
88
- python>=3.6.4
99
- pip
10-
- numpy>=1.21.0,<2.0.0
10+
- numpy>=1.21.0
1111
- scipy>=1.11.0
1212
- pytorch>=1.2.0
1313
- cpuonly

environment-dev.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ channels:
77
dependencies:
88
- python>=3.6.4
99
- pip
10-
- numpy>=1.21.0,<2.0.0
10+
- numpy>=1.21.0
1111
- scipy>=1.11.0
1212
- pytorch>=1.2.0
1313
- cpuonly

environment.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,5 +3,5 @@ channels:
33
- defaults
44
dependencies:
55
- python>=3.6.4
6-
- numpy>=1.21.0,<2.0.0
6+
- numpy>=1.21.0
77
- scipy>=1.14.0

pylops/basicoperators/restriction.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,18 @@
11
__all__ = ["Restriction"]
22

33
import logging
4-
54
from typing import Sequence, Union
65

76
import numpy as np
87
import numpy.ma as np_ma
9-
from numpy.core.multiarray import normalize_axis_index
8+
9+
# need to check numpy version since normalize_axis_index will be
10+
# soon moved from numpy.core.multiarray to from numpy.lib.array_utils
11+
np_version = np.__version__.split(".")
12+
if int(np_version[0]) < 2:
13+
from numpy.core.multiarray import normalize_axis_index
14+
else:
15+
from numpy.lib.array_utils import normalize_axis_index
1016

1117
from pylops import LinearOperator
1218
from pylops.utils._internal import _value_or_sized_to_tuple
@@ -128,8 +134,13 @@ def __init__(
128134
)
129135
forceflat = None
130136

131-
super().__init__(dtype=np.dtype(dtype), dims=dims, dimsd=dimsd,
132-
forceflat=forceflat, name=name)
137+
super().__init__(
138+
dtype=np.dtype(dtype),
139+
dims=dims,
140+
dimsd=dimsd,
141+
forceflat=forceflat,
142+
name=name,
143+
)
133144

134145
iavareshape = np.ones(len(self.dims), dtype=int)
135146
iavareshape[axis] = len(iava)

pylops/linearoperator.py

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1242,23 +1242,14 @@ def _get_dtype(
12421242
) -> DTypeLike:
12431243
if dtypes is None:
12441244
dtypes = []
1245-
opdtypes = []
12461245
for obj in operators:
12471246
if obj is not None and hasattr(obj, "dtype"):
1248-
opdtypes.append(obj.dtype)
1249-
return np.find_common_type(opdtypes, dtypes)
1247+
dtypes.append(obj.dtype)
1248+
return np.result_type(*dtypes)
12501249

12511250

12521251
class _ScaledLinearOperator(LinearOperator):
1253-
"""
1254-
Sum Linear Operator
1255-
1256-
Modified version of scipy _ScaledLinearOperator which uses a modified
1257-
_get_dtype where the scalar and operator types are passed separately to
1258-
np.find_common_type. Passing them together does lead to problems when using
1259-
np.float32 operators which are cast to np.float64
1260-
1261-
"""
1252+
"""Scaled Linear Operator"""
12621253

12631254
def __init__(
12641255
self,
@@ -1269,7 +1260,15 @@ def __init__(
12691260
raise ValueError("LinearOperator expected as A")
12701261
if not np.isscalar(alpha):
12711262
raise ValueError("scalar expected as alpha")
1272-
dtype = _get_dtype([A], [type(alpha)])
1263+
if isinstance(alpha, complex) and not np.iscomplexobj(
1264+
np.ones(1, dtype=A.dtype)
1265+
):
1266+
# if the scalar is of complex type but not the operator, find out type
1267+
dtype = _get_dtype([A], [type(alpha)])
1268+
else:
1269+
# if both the scalar and operator are of real or complex type, use type
1270+
# of the operator
1271+
dtype = A.dtype
12731272
super(_ScaledLinearOperator, self).__init__(dtype=dtype, shape=A.shape)
12741273
self.args = (A, alpha)
12751274

@@ -1465,7 +1464,7 @@ def __init__(self, A: LinearOperator, p: int) -> None:
14651464
if not isintlike(p) or p < 0:
14661465
raise ValueError("non-negative integer expected as p")
14671466

1468-
super(_PowerLinearOperator, self).__init__(dtype=_get_dtype([A]), shape=A.shape)
1467+
super(_PowerLinearOperator, self).__init__(dtype=A.dtype, shape=A.shape)
14691468
self.args = (A, p)
14701469

14711470
def _power(self, fun: Callable, x: NDArray) -> NDArray:

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ classifiers = [
3030
"Topic :: Scientific/Engineering :: Mathematics",
3131
]
3232
dependencies = [
33-
"numpy >= 1.21.0 , < 2.0.0",
33+
"numpy >= 1.21.0",
3434
"scipy >= 1.11.0",
3535
]
3636
dynamic = ["version"]

pytests/test_dtcwt.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@
33

44
from pylops.signalprocessing import DTCWT
55

6+
# currently test only if numpy<2.0.0 is installed...
7+
np_version = np.__version__.split(".")
8+
69
par1 = {"ny": 10, "nx": 10, "dtype": "float64"}
710
par2 = {"ny": 50, "nx": 50, "dtype": "float64"}
811

@@ -17,6 +20,8 @@ def sequential_array(shape):
1720
@pytest.mark.parametrize("par", [(par1), (par2)])
1821
def test_dtcwt1D_input1D(par):
1922
"""Test for DTCWT with 1D input"""
23+
if int(np_version[0]) >= 2:
24+
return
2025

2126
t = sequential_array((par["ny"],))
2227

@@ -31,6 +36,8 @@ def test_dtcwt1D_input1D(par):
3136
@pytest.mark.parametrize("par", [(par1), (par2)])
3237
def test_dtcwt1D_input2D(par):
3338
"""Test for DTCWT with 2D input (forward-inverse pair)"""
39+
if int(np_version[0]) >= 2:
40+
return
3441

3542
t = sequential_array(
3643
(
@@ -50,6 +57,8 @@ def test_dtcwt1D_input2D(par):
5057
@pytest.mark.parametrize("par", [(par1), (par2)])
5158
def test_dtcwt1D_input3D(par):
5259
"""Test for DTCWT with 3D input (forward-inverse pair)"""
60+
if int(np_version[0]) >= 2:
61+
return
5362

5463
t = sequential_array((par["ny"], par["ny"], par["ny"]))
5564

@@ -64,6 +73,9 @@ def test_dtcwt1D_input3D(par):
6473
@pytest.mark.parametrize("par", [(par1), (par2)])
6574
def test_dtcwt1D_birot(par):
6675
"""Test for DTCWT birot (forward-inverse pair)"""
76+
if int(np_version[0]) >= 2:
77+
return
78+
6779
birots = ["antonini", "legall", "near_sym_a", "near_sym_b"]
6880

6981
t = sequential_array(

pytests/test_sparsity.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,9 @@
55
from pylops.basicoperators import FirstDerivative, Identity, MatrixMult
66
from pylops.optimization.sparsity import fista, irls, ista, omp, spgl1, splitbregman
77

8+
# currently test spgl1 only if numpy<2.0.0 is installed...
9+
np_version = np.__version__.split(".")
10+
811
par1 = {
912
"ny": 11,
1013
"nx": 11,
@@ -359,6 +362,9 @@ def test_ISTA_FISTA_multiplerhs(par):
359362
)
360363
def test_SPGL1(par):
361364
"""Invert problem with SPGL1"""
365+
if int(np_version[0]) >= 2:
366+
return
367+
362368
np.random.seed(42)
363369
Aop = MatrixMult(np.random.randn(par["ny"], par["nx"]))
364370

@@ -412,6 +418,6 @@ def test_SplitBregman(par):
412418
x0=x0 if par["x0"] else None,
413419
restart=False,
414420
show=False,
415-
**dict(iter_lim=5, damp=1e-3)
421+
**dict(iter_lim=5, damp=1e-3),
416422
)
417423
assert (np.linalg.norm(x - xinv) / np.linalg.norm(x)) < 1e-1

0 commit comments

Comments
 (0)