@@ -296,13 +296,12 @@ def L_op(self, inputs, outputs, output_gradients):
296
296
# We need to return (dC/d[inv(A)], dC/db)
297
297
c_bar = output_gradients [0 ]
298
298
299
- trans_solve_op = type (self )(
300
- ** {
301
- k : (not getattr (self , k ) if k == "lower" else getattr (self , k ))
302
- for k in self .__props__
303
- }
304
- )
305
- b_bar = trans_solve_op (A .T , c_bar )
299
+ props_dict = self ._props_dict ()
300
+ props_dict ["lower" ] = not self .lower
301
+
302
+ solve_op = type (self )(** props_dict )
303
+
304
+ b_bar = solve_op (A .T , c_bar )
306
305
# force outer product if vector second input
307
306
A_bar = - ptm .outer (b_bar , c ) if c .ndim == 1 else - b_bar .dot (c .T )
308
307
@@ -385,19 +384,17 @@ class SolveTriangular(SolveBase):
385
384
"""Solve a system of linear equations."""
386
385
387
386
__props__ = (
388
- "trans" ,
389
387
"unit_diagonal" ,
390
388
"lower" ,
391
389
"check_finite" ,
392
390
"b_ndim" ,
393
391
"overwrite_b" ,
394
392
)
395
393
396
- def __init__ (self , * , trans = 0 , unit_diagonal = False , ** kwargs ):
394
+ def __init__ (self , * , unit_diagonal = False , ** kwargs ):
397
395
if kwargs .get ("overwrite_a" , False ):
398
396
raise ValueError ("overwrite_a is not supported for SolverTriangulare" )
399
397
super ().__init__ (** kwargs )
400
- self .trans = trans
401
398
self .unit_diagonal = unit_diagonal
402
399
403
400
def perform (self , node , inputs , outputs ):
@@ -406,7 +403,7 @@ def perform(self, node, inputs, outputs):
406
403
A ,
407
404
b ,
408
405
lower = self .lower ,
409
- trans = self . trans ,
406
+ trans = 0 ,
410
407
unit_diagonal = self .unit_diagonal ,
411
408
check_finite = self .check_finite ,
412
409
overwrite_b = self .overwrite_b ,
@@ -445,9 +442,9 @@ def solve_triangular(
445
442
446
443
Parameters
447
444
----------
448
- a
445
+ a: TensorVariable
449
446
Square input data
450
- b
447
+ b: TensorVariable
451
448
Input data for the right hand side.
452
449
lower : bool, optional
453
450
Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
@@ -468,10 +465,17 @@ def solve_triangular(
468
465
This will influence how batched dimensions are interpreted.
469
466
"""
470
467
b_ndim = _default_b_ndim (b , b_ndim )
468
+
469
+ if trans in [1 , "T" , True ]:
470
+ a = a .mT
471
+ lower = not lower
472
+ if trans in [2 , "C" ]:
473
+ a = a .conj ().mT
474
+ lower = not lower
475
+
471
476
ret = Blockwise (
472
477
SolveTriangular (
473
478
lower = lower ,
474
- trans = trans ,
475
479
unit_diagonal = unit_diagonal ,
476
480
check_finite = check_finite ,
477
481
b_ndim = b_ndim ,
@@ -534,6 +538,7 @@ def solve(
534
538
* ,
535
539
assume_a = "gen" ,
536
540
lower = False ,
541
+ transposed = False ,
537
542
check_finite = True ,
538
543
b_ndim : int | None = None ,
539
544
):
@@ -564,8 +569,10 @@ def solve(
564
569
b : (..., N, NRHS) array_like
565
570
Input data for the right hand side.
566
571
lower : bool, optional
567
- If True, only the data contained in the lower triangle of `a`. Default
572
+ If True, use only the data contained in the lower triangle of `a`. Default
568
573
is to use upper triangle. (ignored for ``'gen'``)
574
+ transposed: bool, optional
575
+ If True, solves the system A^T x = b. Default is False.
569
576
check_finite : bool, optional
570
577
Whether to check that the input matrices contain only finite numbers.
571
578
Disabling may give a performance gain, but may result in problems
@@ -577,6 +584,11 @@ def solve(
577
584
This will influence how batched dimensions are interpreted.
578
585
"""
579
586
b_ndim = _default_b_ndim (b , b_ndim )
587
+
588
+ if transposed :
589
+ a = a .mT
590
+ lower = not lower
591
+
580
592
return Blockwise (
581
593
Solve (
582
594
lower = lower ,
0 commit comments