Skip to content

Refactor and update QR Op #1518

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 4 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 0 additions & 11 deletions pytensor/link/jax/dispatch/nlinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@
KroneckerProduct,
MatrixInverse,
MatrixPinv,
QRFull,
SLogDet,
)

Expand Down Expand Up @@ -67,16 +66,6 @@ def matrix_inverse(x):
return matrix_inverse


@jax_funcify.register(QRFull)
def jax_funcify_QRFull(op, **kwargs):
mode = op.mode

def qr_full(x, mode=mode):
return jnp.linalg.qr(x, mode=mode)

return qr_full


@jax_funcify.register(MatrixPinv)
def jax_funcify_Pinv(op, **kwargs):
def pinv(x):
Expand Down
11 changes: 11 additions & 0 deletions pytensor/link/jax/dispatch/slinalg.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from pytensor.link.jax.dispatch.basic import jax_funcify
from pytensor.tensor.slinalg import (
LU,
QR,
BlockDiagonal,
Cholesky,
CholeskySolve,
Expand Down Expand Up @@ -168,3 +169,13 @@ def cho_solve(c, b):
)

return cho_solve


@jax_funcify.register(QR)
def jax_funcify_QR(op, **kwargs):
mode = op.mode

def qr(x, mode=mode):
return jax.scipy.linalg.qr(x, mode=mode)

return qr
88 changes: 87 additions & 1 deletion pytensor/link/numba/dispatch/linalg/_LAPACK.py
Original file line number Diff line number Diff line change
Expand Up @@ -283,7 +283,6 @@ def numba_xgetrs(cls, dtype):

Called by scipy.linalg.lu_solve
"""
...
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
functype = ctypes.CFUNCTYPE(
None,
Expand Down Expand Up @@ -457,3 +456,90 @@ def numba_xgtcon(cls, dtype):
_ptr_int, # INFO
)
return functype(lapack_ptr)

@classmethod
def numba_xgeqrf(cls, dtype):
"""
Compute the QR factorization of a general M-by-N matrix A.

Used in QR decomposition (no pivoting).
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqrf")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # M
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
_ptr_int, # INFO
)
return functype(lapack_ptr)

@classmethod
def numba_xgeqp3(cls, dtype):
"""
Compute the QR factorization with column pivoting of a general M-by-N matrix A.

Used in QR decomposition with pivoting.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # M
_ptr_int, # N
float_pointer, # A
_ptr_int, # LDA
_ptr_int, # JPVT
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
_ptr_int, # INFO
)
return functype(lapack_ptr)

@classmethod
def numba_xorgqr(cls, dtype):
"""
Generate the orthogonal matrix Q from a QR factorization (real types).

Used in QR decomposition to form Q.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "orgqr")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # M
_ptr_int, # N
_ptr_int, # K
float_pointer, # A
_ptr_int, # LDA
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
_ptr_int, # INFO
)
return functype(lapack_ptr)

@classmethod
def numba_xungqr(cls, dtype):
"""
Generate the unitary matrix Q from a QR factorization (complex types).

Used in QR decomposition to form Q for complex types.
"""
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "ungqr")
functype = ctypes.CFUNCTYPE(
None,
_ptr_int, # M
_ptr_int, # N
_ptr_int, # K
float_pointer, # A
_ptr_int, # LDA
float_pointer, # TAU
float_pointer, # WORK
_ptr_int, # LWORK
_ptr_int, # INFO
)
return functype(lapack_ptr)
Loading
Loading