Skip to content

Commit 617964f

Browse files
Refactor and update QR Op (#1518)
* Refactor QR * Update JAX QR dispatch * Update Torch QR dispatch * Update numba QR dispatch
1 parent 5024d54 commit 617964f

File tree

20 files changed

+1703
-424
lines changed

20 files changed

+1703
-424
lines changed

pytensor/link/jax/dispatch/nlinalg.py

Lines changed: 0 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
KroneckerProduct,
1010
MatrixInverse,
1111
MatrixPinv,
12-
QRFull,
1312
SLogDet,
1413
)
1514

@@ -67,16 +66,6 @@ def matrix_inverse(x):
6766
return matrix_inverse
6867

6968

70-
@jax_funcify.register(QRFull)
71-
def jax_funcify_QRFull(op, **kwargs):
72-
mode = op.mode
73-
74-
def qr_full(x, mode=mode):
75-
return jnp.linalg.qr(x, mode=mode)
76-
77-
return qr_full
78-
79-
8069
@jax_funcify.register(MatrixPinv)
8170
def jax_funcify_Pinv(op, **kwargs):
8271
def pinv(x):

pytensor/link/jax/dispatch/slinalg.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pytensor.link.jax.dispatch.basic import jax_funcify
66
from pytensor.tensor.slinalg import (
77
LU,
8+
QR,
89
BlockDiagonal,
910
Cholesky,
1011
CholeskySolve,
@@ -168,3 +169,13 @@ def cho_solve(c, b):
168169
)
169170

170171
return cho_solve
172+
173+
174+
@jax_funcify.register(QR)
175+
def jax_funcify_QR(op, **kwargs):
176+
mode = op.mode
177+
178+
def qr(x, mode=mode):
179+
return jax.scipy.linalg.qr(x, mode=mode)
180+
181+
return qr

pytensor/link/numba/dispatch/linalg/_LAPACK.py

Lines changed: 87 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -283,7 +283,6 @@ def numba_xgetrs(cls, dtype):
283283
284284
Called by scipy.linalg.lu_solve
285285
"""
286-
...
287286
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "getrs")
288287
functype = ctypes.CFUNCTYPE(
289288
None,
@@ -457,3 +456,90 @@ def numba_xgtcon(cls, dtype):
457456
_ptr_int, # INFO
458457
)
459458
return functype(lapack_ptr)
459+
460+
@classmethod
461+
def numba_xgeqrf(cls, dtype):
462+
"""
463+
Compute the QR factorization of a general M-by-N matrix A.
464+
465+
Used in QR decomposition (no pivoting).
466+
"""
467+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqrf")
468+
functype = ctypes.CFUNCTYPE(
469+
None,
470+
_ptr_int, # M
471+
_ptr_int, # N
472+
float_pointer, # A
473+
_ptr_int, # LDA
474+
float_pointer, # TAU
475+
float_pointer, # WORK
476+
_ptr_int, # LWORK
477+
_ptr_int, # INFO
478+
)
479+
return functype(lapack_ptr)
480+
481+
@classmethod
482+
def numba_xgeqp3(cls, dtype):
483+
"""
484+
Compute the QR factorization with column pivoting of a general M-by-N matrix A.
485+
486+
Used in QR decomposition with pivoting.
487+
"""
488+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "geqp3")
489+
functype = ctypes.CFUNCTYPE(
490+
None,
491+
_ptr_int, # M
492+
_ptr_int, # N
493+
float_pointer, # A
494+
_ptr_int, # LDA
495+
_ptr_int, # JPVT
496+
float_pointer, # TAU
497+
float_pointer, # WORK
498+
_ptr_int, # LWORK
499+
_ptr_int, # INFO
500+
)
501+
return functype(lapack_ptr)
502+
503+
@classmethod
504+
def numba_xorgqr(cls, dtype):
505+
"""
506+
Generate the orthogonal matrix Q from a QR factorization (real types).
507+
508+
Used in QR decomposition to form Q.
509+
"""
510+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "orgqr")
511+
functype = ctypes.CFUNCTYPE(
512+
None,
513+
_ptr_int, # M
514+
_ptr_int, # N
515+
_ptr_int, # K
516+
float_pointer, # A
517+
_ptr_int, # LDA
518+
float_pointer, # TAU
519+
float_pointer, # WORK
520+
_ptr_int, # LWORK
521+
_ptr_int, # INFO
522+
)
523+
return functype(lapack_ptr)
524+
525+
@classmethod
526+
def numba_xungqr(cls, dtype):
527+
"""
528+
Generate the unitary matrix Q from a QR factorization (complex types).
529+
530+
Used in QR decomposition to form Q for complex types.
531+
"""
532+
lapack_ptr, float_pointer = _get_lapack_ptr_and_ptr_type(dtype, "ungqr")
533+
functype = ctypes.CFUNCTYPE(
534+
None,
535+
_ptr_int, # M
536+
_ptr_int, # N
537+
_ptr_int, # K
538+
float_pointer, # A
539+
_ptr_int, # LDA
540+
float_pointer, # TAU
541+
float_pointer, # WORK
542+
_ptr_int, # LWORK
543+
_ptr_int, # INFO
544+
)
545+
return functype(lapack_ptr)

0 commit comments

Comments
 (0)