@@ -1242,23 +1242,14 @@ def _get_dtype(
1242
1242
) -> DTypeLike :
1243
1243
if dtypes is None :
1244
1244
dtypes = []
1245
- opdtypes = []
1246
1245
for obj in operators :
1247
1246
if obj is not None and hasattr (obj , "dtype" ):
1248
- opdtypes .append (obj .dtype )
1249
- return np .find_common_type ( opdtypes , dtypes )
1247
+ dtypes .append (obj .dtype )
1248
+ return np .result_type ( * dtypes )
1250
1249
1251
1250
1252
1251
class _ScaledLinearOperator (LinearOperator ):
1253
- """
1254
- Sum Linear Operator
1255
-
1256
- Modified version of scipy _ScaledLinearOperator which uses a modified
1257
- _get_dtype where the scalar and operator types are passed separately to
1258
- np.find_common_type. Passing them together does lead to problems when using
1259
- np.float32 operators which are cast to np.float64
1260
-
1261
- """
1252
+ """Scaled Linear Operator"""
1262
1253
1263
1254
def __init__ (
1264
1255
self ,
@@ -1269,7 +1260,15 @@ def __init__(
1269
1260
raise ValueError ("LinearOperator expected as A" )
1270
1261
if not np .isscalar (alpha ):
1271
1262
raise ValueError ("scalar expected as alpha" )
1272
- dtype = _get_dtype ([A ], [type (alpha )])
1263
+ if isinstance (alpha , complex ) and not np .iscomplexobj (
1264
+ np .ones (1 , dtype = A .dtype )
1265
+ ):
1266
+ # if the scalar is of complex type but not the operator, find out type
1267
+ dtype = _get_dtype ([A ], [type (alpha )])
1268
+ else :
1269
+ # if both the scalar and operator are of real or complex type, use type
1270
+ # of the operator
1271
+ dtype = A .dtype
1273
1272
super (_ScaledLinearOperator , self ).__init__ (dtype = dtype , shape = A .shape )
1274
1273
self .args = (A , alpha )
1275
1274
@@ -1465,7 +1464,7 @@ def __init__(self, A: LinearOperator, p: int) -> None:
1465
1464
if not isintlike (p ) or p < 0 :
1466
1465
raise ValueError ("non-negative integer expected as p" )
1467
1466
1468
- super (_PowerLinearOperator , self ).__init__ (dtype = _get_dtype ([ A ]) , shape = A .shape )
1467
+ super (_PowerLinearOperator , self ).__init__ (dtype = A . dtype , shape = A .shape )
1469
1468
self .args = (A , p )
1470
1469
1471
1470
def _power (self , fun : Callable , x : NDArray ) -> NDArray :
0 commit comments