diff --git a/doc/conf.py b/doc/conf.py index 1729efc4b1..e10dcffb90 100644 --- a/doc/conf.py +++ b/doc/conf.py @@ -1,6 +1,7 @@ import os import inspect import sys + import pytensor from pathlib import Path @@ -12,6 +13,7 @@ "sphinx.ext.autodoc", "sphinx.ext.todo", "sphinx.ext.doctest", + "sphinx_copybutton", "sphinx.ext.napoleon", "sphinx.ext.linkcode", "sphinx.ext.mathjax", @@ -86,8 +88,7 @@ # List of directories, relative to source directories, that shouldn't be # searched for source files. -exclude_dirs = ["images", "scripts", "sandbox"] -exclude_patterns = ['page_footer.md', '**/*.myst.md'] +exclude_patterns = ["README.md", "images/*", "page_footer.md", "**/*.myst.md"] # The reST default role (used for this markup: `text`) to use for all # documents. @@ -235,24 +236,41 @@ # Resolve function # This function is used to populate the (source) links in the API def linkcode_resolve(domain, info): - def find_source(): + def find_obj() -> object: # try to find the file and line number, based on code from numpy: # https://github.com/numpy/numpy/blob/master/doc/source/conf.py#L286 obj = sys.modules[info["module"]] for part in info["fullname"].split("."): obj = getattr(obj, part) + return obj + def find_source(obj): fn = Path(inspect.getsourcefile(obj)) - fn = fn.relative_to(Path(__file__).parent) + fn = fn.relative_to(Path(pytensor.__file__).parent) source, lineno = inspect.getsourcelines(obj) return fn, lineno, lineno + len(source) - 1 + def fallback_source(): + return info["module"].replace(".", "/") + ".py" + if domain != "py" or not info["module"]: return None + try: - filename = "pytensor/%s#L%d-L%d" % find_source() + obj = find_obj() except Exception: - filename = info["module"].replace(".", "/") + ".py" + filename = fallback_source() + else: + try: + filename = "pytensor/%s#L%d-L%d" % find_source(obj) + except Exception: + # warnings.warn(f"Could not find source code for {domain}:{info}") + try: + filename = obj.__module__.replace(".", "/") + ".py" + except 
AttributeError: + # Some objects do not have a __module__ attribute (?) + filename = fallback_source() + import subprocess tag = subprocess.Popen( diff --git a/doc/environment.yml b/doc/environment.yml index d58af79cc6..7b564e8fb0 100644 --- a/doc/environment.yml +++ b/doc/environment.yml @@ -13,7 +13,9 @@ dependencies: - mock - pillow - pymc-sphinx-theme + - sphinx-copybutton - sphinx-design + - sphinx-sitemap - pygments - pydot - ipython @@ -23,5 +25,4 @@ dependencies: - ablog - pip - pip: - - sphinx_sitemap - -e .. diff --git a/doc/library/index.rst b/doc/library/index.rst index 08a5b51c34..e9b362f8db 100644 --- a/doc/library/index.rst +++ b/doc/library/index.rst @@ -20,14 +20,12 @@ Modules d3viz/index graph/index gradient - misc/pkl_utils printing - scalar/index scan sparse/index - sparse/sandbox tensor/index typed_list + xtensor/index .. module:: pytensor :platform: Unix, Windows diff --git a/doc/library/xtensor/index.md b/doc/library/xtensor/index.md new file mode 100644 index 0000000000..3ebb852773 --- /dev/null +++ b/doc/library/xtensor/index.md @@ -0,0 +1,101 @@ +(libdoc_xtensor)= +# `xtensor` -- XTensor operations + +This module implements an abstraction layer on regular tensor operations, that behaves like Xarray. + +A new type {class}`pytensor.xtensor.type.XTensorType`, generalizes the {class}`pytensor.tensor.TensorType` +with the addition of a `dims` attribute, that labels the dimensions of the tensor. + +Variables of XTensorType (i.e., {class}`pytensor.xtensor.type.XTensorVariable`s) are the symbolic counterpart +to xarray DataArray objects. + +The module implements several PyTensor operations {class}`pytensor.xtensor.basic.XOp`s, whose signature mimics that of +xarray (and xarray_einstats) DataArray operations. These operations, unlike most regular PyTensor operations, cannot +be directly evaluated, but require a rewrite (lowering) into a regular tensor graph that can itself be evaluated as usual. 
+ +Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray. +If the existing XOps can be composed to produce the desired result, then we can use them directly. + +## Coordinates +For now, there's no analogue of xarray coordinates, so you won't be able to do coordinate operations like `.sel`. +The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor. +Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor. + +## Example + + +```{testcode} + +import pytensor.tensor as pt +import pytensor.xtensor as ptx + +a = pt.tensor("a", shape=(3,)) +b = pt.tensor("b", shape=(4,)) + +ax = ptx.as_xtensor(a, dims=["x"]) +bx = ptx.as_xtensor(b, dims=["y"]) + +zx = ax + bx +assert zx.type == ptx.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4)) + +z = zx.values +z.dprint() +``` + + +```{testoutput} + +TensorFromXTensor [id A] + └─ XElemwise{scalar_op=Add()} [id B] + ├─ XTensorFromTensor{dims=('x',)} [id C] + │ └─ a [id D] + └─ XTensorFromTensor{dims=('y',)} [id E] + └─ b [id F] +``` + +Once we compile the graph, no XOps are left. 
+ +```{testcode} + +import pytensor + +with pytensor.config.change_flags(optimizer_verbose=True): + fn = pytensor.function([a, b], z) + +``` + +```{testoutput} + +rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0) +rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None +rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None +rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0) + +``` + +```{testcode} + +fn.dprint() +``` + +```{testoutput} + +Add [id A] 2 + ├─ ExpandDims{axis=1} [id B] 1 + │ └─ a [id C] + └─ ExpandDims{axis=0} [id D] 0 + └─ b [id E] +``` + + +## Index + +:::{toctree} +:maxdepth: 1 + +module_functions +math +linalg +random +type +::: \ No newline at end of file diff --git a/doc/library/xtensor/linalg.md b/doc/library/xtensor/linalg.md new file mode 100644 index 0000000000..3861be1398 --- /dev/null +++ b/doc/library/xtensor/linalg.md @@ -0,0 +1,7 @@ +(libdoc_xtensor_linalg)= +# `xtensor.linalg` -- Linear algebra operations + +```{eval-rst} +.. automodule:: pytensor.xtensor.linalg + :members: +``` diff --git a/doc/library/xtensor/math.md b/doc/library/xtensor/math.md new file mode 100644 index 0000000000..b87e836b87 --- /dev/null +++ b/doc/library/xtensor/math.md @@ -0,0 +1,8 @@ +(libdoc_xtensor_math)= +# `xtensor.math` Mathematical operations + +```{eval-rst} +.. 
automodule:: pytensor.xtensor.math + :members: + :exclude-members: XDot, dot +``` \ No newline at end of file diff --git a/doc/library/xtensor/module_functions.md b/doc/library/xtensor/module_functions.md new file mode 100644 index 0000000000..861e969f60 --- /dev/null +++ b/doc/library/xtensor/module_functions.md @@ -0,0 +1,7 @@ +(libdoc_xtensor_module_function)= +# `xtensor` -- Module level operations + +```{eval-rst} +.. automodule:: pytensor.xtensor + :members: broadcast, concat, dot, full_like, ones_like, zeros_like +``` diff --git a/doc/library/xtensor/random.md b/doc/library/xtensor/random.md new file mode 100644 index 0000000000..5be741beca --- /dev/null +++ b/doc/library/xtensor/random.md @@ -0,0 +1,7 @@ +(libdoc_xtensor_random)= +# `xtensor.random` Random number generator operations + +```{eval-rst} +.. automodule:: pytensor.xtensor.random + :members: +``` diff --git a/doc/library/xtensor/type.md b/doc/library/xtensor/type.md new file mode 100644 index 0000000000..083e8ba12c --- /dev/null +++ b/doc/library/xtensor/type.md @@ -0,0 +1,21 @@ +(libdoc_xtensor_type)= + +# `xtensor.type` -- Types and Variables + +## XTensorVariable creation functions + +```{eval-rst} +.. automodule:: pytensor.xtensor.type + :members: xtensor, xtensor_constant, as_xtensor + +``` + +## XTensor Type and Variable classes + +```{eval-rst} +.. 
automodule:: pytensor.xtensor.type + :noindex: + :members: XTensorType, XTensorVariable, XTensorConstant +``` + + diff --git a/pytensor/xtensor/__init__.py b/pytensor/xtensor/__init__.py index 259701c435..7292bea131 100644 --- a/pytensor/xtensor/__init__.py +++ b/pytensor/xtensor/__init__.py @@ -1,7 +1,7 @@ import warnings import pytensor.xtensor.rewriting -from pytensor.xtensor import linalg, random +from pytensor.xtensor import linalg, math, random from pytensor.xtensor.math import dot from pytensor.xtensor.shape import broadcast, concat, full_like, ones_like, zeros_like from pytensor.xtensor.type import ( diff --git a/pytensor/xtensor/linalg.py b/pytensor/xtensor/linalg.py index b1de9459df..da03dfb086 100644 --- a/pytensor/xtensor/linalg.py +++ b/pytensor/xtensor/linalg.py @@ -11,17 +11,31 @@ def cholesky( lower: bool = True, *, check_finite: bool = False, - overwrite_a: bool = False, on_error: Literal["raise", "nan"] = "raise", dims: Sequence[str], ): + """Compute the Cholesky decomposition of an XTensorVariable. + + Parameters + ---------- + x : XTensorVariable + The input variable to decompose. + lower : bool, optional + Whether to return the lower triangular matrix. Default is True. + check_finite : bool, optional + Whether to check that the input is finite. Default is False. + on_error : {'raise', 'nan'}, optional + What to do if the input is not positive definite. If 'raise', an error is raised. + If 'nan', the output will contain NaNs. Default is 'raise'. + dims : Sequence[str] + The two core dimensions of the input variable, over which the Cholesky decomposition is computed. + """ if len(dims) != 2: raise ValueError(f"Cholesky needs two dims, got {len(dims)}") core_op = Cholesky( lower=lower, check_finite=check_finite, - overwrite_a=overwrite_a, on_error=on_error, ) core_dims = ( @@ -40,6 +54,30 @@ def solve( lower: bool = False, check_finite: bool = False, ): + """Solve a system of linear equations using XTensorVariables. 
+ + Parameters + ---------- + a : XTensorVariable + The left-hand side xtensor. + b : XTensorVariable + The right-hand side xtensor. + dims : Sequence[str] + The core dimensions over which to solve the linear equations. + If length is 2, we are solving a matrix-vector equation, + and the two dimensions should be present in `a`, but only one in `b`. + If length is 3, we are solving a matrix-matrix equation, + and two dimensions should be present in `a`, two in `b`, and only one should be shared. + In both cases the shared dimension will not appear in the output. + assume_a : str, optional + The type of matrix that `a` is assumed to be. Default is 'gen' (general). + Options are ["gen", "sym", "her", "pos", "tridiagonal", "banded"]. + Long form options can also be used ["general", "symmetric", "hermitian", "positive_definite"]. + lower : bool, optional + Whether `a` is lower triangular. Default is False. Only relevant if `assume_a` is "sym", "her", or "pos". + check_finite : bool, optional + Whether to check that the input is finite. Default is False. 
+ """ a, b = as_xtensor(a), as_xtensor(b) input_core_dims: tuple[tuple[str, str], tuple[str] | tuple[str, str]] output_core_dims: tuple[tuple[str] | tuple[str, str]] diff --git a/pytensor/xtensor/math.py b/pytensor/xtensor/math.py index 687d7220d7..af453d16e9 100644 --- a/pytensor/xtensor/math.py +++ b/pytensor/xtensor/math.py @@ -1,5 +1,5 @@ import sys -from collections.abc import Iterable +from collections.abc import Iterable, Sequence from types import EllipsisType import numpy as np @@ -7,7 +7,6 @@ import pytensor.scalar as ps from pytensor import config from pytensor.graph.basic import Apply -from pytensor.scalar import ScalarOp from pytensor.scalar.basic import _cast_mapping, upcast from pytensor.xtensor.basic import XOp, as_xtensor from pytensor.xtensor.type import xtensor @@ -17,110 +16,477 @@ this_module = sys.modules[__name__] -def _as_xelemwise(core_op: ScalarOp) -> XElemwise: - out = XElemwise(core_op) - out.__doc__ = f"Ufunc version of {core_op} for XTensorVariables" - return out - - -abs = _as_xelemwise(ps.abs) -add = _as_xelemwise(ps.add) -logical_and = bitwise_and = and_ = _as_xelemwise(ps.and_) -angle = _as_xelemwise(ps.angle) -arccos = _as_xelemwise(ps.arccos) -arccosh = _as_xelemwise(ps.arccosh) -arcsin = _as_xelemwise(ps.arcsin) -arcsinh = _as_xelemwise(ps.arcsinh) -arctan = _as_xelemwise(ps.arctan) -arctan2 = _as_xelemwise(ps.arctan2) -arctanh = _as_xelemwise(ps.arctanh) -betainc = _as_xelemwise(ps.betainc) -betaincinv = _as_xelemwise(ps.betaincinv) -ceil = _as_xelemwise(ps.ceil) -clip = _as_xelemwise(ps.clip) -complex = _as_xelemwise(ps.complex) -conjugate = conj = _as_xelemwise(ps.conj) -cos = _as_xelemwise(ps.cos) -cosh = _as_xelemwise(ps.cosh) -deg2rad = _as_xelemwise(ps.deg2rad) -equal = eq = _as_xelemwise(ps.eq) -erf = _as_xelemwise(ps.erf) -erfc = _as_xelemwise(ps.erfc) -erfcinv = _as_xelemwise(ps.erfcinv) -erfcx = _as_xelemwise(ps.erfcx) -erfinv = _as_xelemwise(ps.erfinv) -exp = _as_xelemwise(ps.exp) -exp2 = _as_xelemwise(ps.exp2) 
-expm1 = _as_xelemwise(ps.expm1) -floor = _as_xelemwise(ps.floor) -floor_divide = floor_div = int_div = _as_xelemwise(ps.int_div) -gamma = _as_xelemwise(ps.gamma) -gammainc = _as_xelemwise(ps.gammainc) -gammaincc = _as_xelemwise(ps.gammaincc) -gammainccinv = _as_xelemwise(ps.gammainccinv) -gammaincinv = _as_xelemwise(ps.gammaincinv) -gammal = _as_xelemwise(ps.gammal) -gammaln = _as_xelemwise(ps.gammaln) -gammau = _as_xelemwise(ps.gammau) -greater_equal = ge = _as_xelemwise(ps.ge) -greater = gt = _as_xelemwise(ps.gt) -hyp2f1 = _as_xelemwise(ps.hyp2f1) -i0 = _as_xelemwise(ps.i0) -i1 = _as_xelemwise(ps.i1) -identity = _as_xelemwise(ps.identity) -imag = _as_xelemwise(ps.imag) -logical_not = bitwise_invert = bitwise_not = invert = _as_xelemwise(ps.invert) -isinf = _as_xelemwise(ps.isinf) -isnan = _as_xelemwise(ps.isnan) -iv = _as_xelemwise(ps.iv) -ive = _as_xelemwise(ps.ive) -j0 = _as_xelemwise(ps.j0) -j1 = _as_xelemwise(ps.j1) -jv = _as_xelemwise(ps.jv) -kve = _as_xelemwise(ps.kve) -less_equal = le = _as_xelemwise(ps.le) -log = _as_xelemwise(ps.log) -log10 = _as_xelemwise(ps.log10) -log1mexp = _as_xelemwise(ps.log1mexp) -log1p = _as_xelemwise(ps.log1p) -log2 = _as_xelemwise(ps.log2) -less = lt = _as_xelemwise(ps.lt) -mod = _as_xelemwise(ps.mod) -multiply = mul = _as_xelemwise(ps.mul) -negative = neg = _as_xelemwise(ps.neg) -not_equal = neq = _as_xelemwise(ps.neq) -logical_or = bitwise_or = or_ = _as_xelemwise(ps.or_) -owens_t = _as_xelemwise(ps.owens_t) -polygamma = _as_xelemwise(ps.polygamma) -power = pow = _as_xelemwise(ps.pow) -psi = _as_xelemwise(ps.psi) -rad2deg = _as_xelemwise(ps.rad2deg) -real = _as_xelemwise(ps.real) -reciprocal = _as_xelemwise(ps.reciprocal) -round = _as_xelemwise(ps.round_half_to_even) -maximum = _as_xelemwise(ps.scalar_maximum) -minimum = _as_xelemwise(ps.scalar_minimum) -second = _as_xelemwise(ps.second) -sigmoid = expit = _as_xelemwise(ps.sigmoid) -sign = _as_xelemwise(ps.sign) -sin = _as_xelemwise(ps.sin) -sinh = _as_xelemwise(ps.sinh) 
-softplus = _as_xelemwise(ps.softplus) -square = sqr = _as_xelemwise(ps.sqr) -sqrt = _as_xelemwise(ps.sqrt) -subtract = sub = _as_xelemwise(ps.sub) -where = switch = _as_xelemwise(ps.switch) -tan = _as_xelemwise(ps.tan) -tanh = _as_xelemwise(ps.tanh) -tri_gamma = _as_xelemwise(ps.tri_gamma) -true_divide = true_div = _as_xelemwise(ps.true_div) -trunc = _as_xelemwise(ps.trunc) -logical_xor = bitwise_xor = xor = _as_xelemwise(ps.xor) +def _as_xelemwise(core_op): + x_op = XElemwise(core_op) + + def decorator(func): + def wrapper(*args, **kwargs): + return x_op(*args, **kwargs) + + wrapper.__doc__ = f"Ufunc version of {core_op} for XTensorVariables" + return wrapper + + return decorator + + +@_as_xelemwise(ps.abs) +def abs(): ... + + +@_as_xelemwise(ps.add) +def add(): ... + + +@_as_xelemwise(ps.and_) +def logical_and(): ... + + +@_as_xelemwise(ps.and_) +def bitwise_and(): ... + + +and_ = logical_and + + +@_as_xelemwise(ps.angle) +def angle(): ... + + +@_as_xelemwise(ps.arccos) +def arccos(): ... + + +@_as_xelemwise(ps.arccosh) +def arccosh(): ... + + +@_as_xelemwise(ps.arcsin) +def arcsin(): ... + + +@_as_xelemwise(ps.arcsinh) +def arcsinh(): ... + + +@_as_xelemwise(ps.arctan) +def arctan(): ... + + +@_as_xelemwise(ps.arctan2) +def arctan2(): ... + + +@_as_xelemwise(ps.arctanh) +def arctanh(): ... + + +@_as_xelemwise(ps.betainc) +def betainc(): ... + + +@_as_xelemwise(ps.betaincinv) +def betaincinv(): ... + + +@_as_xelemwise(ps.ceil) +def ceil(): ... + + +@_as_xelemwise(ps.clip) +def clip(): ... + + +@_as_xelemwise(ps.complex) +def complex(): ... + + +@_as_xelemwise(ps.conj) +def conjugate(): ... + + +conj = conjugate + + +@_as_xelemwise(ps.cos) +def cos(): ... + + +@_as_xelemwise(ps.cosh) +def cosh(): ... + + +@_as_xelemwise(ps.deg2rad) +def deg2rad(): ... + + +@_as_xelemwise(ps.eq) +def equal(): ... + + +eq = equal + + +@_as_xelemwise(ps.erf) +def erf(): ... + + +@_as_xelemwise(ps.erfc) +def erfc(): ... + + +@_as_xelemwise(ps.erfcinv) +def erfcinv(): ... 
+ + +@_as_xelemwise(ps.erfcx) +def erfcx(): ... + + +@_as_xelemwise(ps.erfinv) +def erfinv(): ... + + +@_as_xelemwise(ps.exp) +def exp(): ... + + +@_as_xelemwise(ps.exp2) +def exp2(): ... + + +@_as_xelemwise(ps.expm1) +def expm1(): ... + + +@_as_xelemwise(ps.floor) +def floor(): ... + + +@_as_xelemwise(ps.int_div) +def floor_divide(): ... + + +floor_div = int_div = floor_divide + + +@_as_xelemwise(ps.gamma) +def gamma(): ... + + +@_as_xelemwise(ps.gammainc) +def gammainc(): ... + + +@_as_xelemwise(ps.gammaincc) +def gammaincc(): ... + + +@_as_xelemwise(ps.gammainccinv) +def gammainccinv(): ... + + +@_as_xelemwise(ps.gammaincinv) +def gammaincinv(): ... + + +@_as_xelemwise(ps.gammal) +def gammal(): ... + + +@_as_xelemwise(ps.gammaln) +def gammaln(): ... + + +@_as_xelemwise(ps.gammau) +def gammau(): ... + + +@_as_xelemwise(ps.ge) +def greater_equal(): ... + + +ge = greater_equal + + +@_as_xelemwise(ps.gt) +def greater(): ... + + +gt = greater + + +@_as_xelemwise(ps.hyp2f1) +def hyp2f1(): ... + + +@_as_xelemwise(ps.i0) +def i0(): ... + + +@_as_xelemwise(ps.i1) +def i1(): ... + + +@_as_xelemwise(ps.identity) +def identity(): ... + + +@_as_xelemwise(ps.imag) +def imag(): ... + + +@_as_xelemwise(ps.invert) +def logical_not(): ... + + +@_as_xelemwise(ps.invert) +def bitwise_not(): ... + + +@_as_xelemwise(ps.invert) +def bitwise_invert(): ... + + +@_as_xelemwise(ps.invert) +def invert(): ... + + +@_as_xelemwise(ps.isinf) +def isinf(): ... + + +@_as_xelemwise(ps.isnan) +def isnan(): ... + + +@_as_xelemwise(ps.iv) +def iv(): ... + + +@_as_xelemwise(ps.ive) +def ive(): ... + + +@_as_xelemwise(ps.j0) +def j0(): ... + + +@_as_xelemwise(ps.j1) +def j1(): ... + + +@_as_xelemwise(ps.jv) +def jv(): ... + + +@_as_xelemwise(ps.kve) +def kve(): ... + + +@_as_xelemwise(ps.le) +def less_equal(): ... + + +le = less_equal + + +@_as_xelemwise(ps.log) +def log(): ... + + +@_as_xelemwise(ps.log10) +def log10(): ... + + +@_as_xelemwise(ps.log1mexp) +def log1mexp(): ... 
+ + +@_as_xelemwise(ps.log1p) +def log1p(): ... + + +@_as_xelemwise(ps.log2) +def log2(): ... + + +@_as_xelemwise(ps.lt) +def less(): ... + + +lt = less + + +@_as_xelemwise(ps.mod) +def mod(): ... + + +@_as_xelemwise(ps.mul) +def multiply(): ... + + +mul = multiply + + +@_as_xelemwise(ps.neg) +def negative(): ... + + +neg = negative + + +@_as_xelemwise(ps.neq) +def not_equal(): ... + + +neq = not_equal + + +@_as_xelemwise(ps.or_) +def logical_or(): ... + + +@_as_xelemwise(ps.or_) +def bitwise_or(): ... + + +or_ = logical_or + + +@_as_xelemwise(ps.owens_t) +def owens_t(): ... + + +@_as_xelemwise(ps.polygamma) +def polygamma(): ... + + +@_as_xelemwise(ps.pow) +def power(): ... + + +pow = power + + +@_as_xelemwise(ps.psi) +def psi(): ... + + +@_as_xelemwise(ps.rad2deg) +def rad2deg(): ... + + +@_as_xelemwise(ps.real) +def real(): ... + + +@_as_xelemwise(ps.reciprocal) +def reciprocal(): ... + + +@_as_xelemwise(ps.round_half_to_even) +def round(): ... + + +@_as_xelemwise(ps.scalar_maximum) +def maximum(): ... + + +@_as_xelemwise(ps.scalar_minimum) +def minimum(): ... + + +@_as_xelemwise(ps.second) +def second(): ... + + +@_as_xelemwise(ps.sigmoid) +def sigmoid(): ... + + +expit = sigmoid + + +@_as_xelemwise(ps.sign) +def sign(): ... + + +@_as_xelemwise(ps.sin) +def sin(): ... + + +@_as_xelemwise(ps.sinh) +def sinh(): ... + + +@_as_xelemwise(ps.softplus) +def softplus(): ... + + +@_as_xelemwise(ps.sqr) +def square(): ... + + +sqr = square + + +@_as_xelemwise(ps.sqrt) +def sqrt(): ... + + +@_as_xelemwise(ps.sub) +def subtract(): ... + + +sub = subtract + + +@_as_xelemwise(ps.switch) +def where(): ... + + +switch = where + + +@_as_xelemwise(ps.tan) +def tan(): ... + + +@_as_xelemwise(ps.tanh) +def tanh(): ... + + +@_as_xelemwise(ps.tri_gamma) +def tri_gamma(): ... + + +@_as_xelemwise(ps.true_div) +def true_divide(): ... + + +true_div = true_divide + + +@_as_xelemwise(ps.trunc) +def trunc(): ... + + +@_as_xelemwise(ps.xor) +def logical_xor(): ... 
+ + +@_as_xelemwise(ps.xor) +def bitwise_xor(): ... + + +xor = logical_xor + _xelemwise_cast_op: dict[str, XElemwise] = {} def cast(x, dtype): + """Cast an XTensorVariable to a different dtype.""" if dtype == "floatX": dtype = config.floatX else: @@ -141,6 +507,7 @@ def cast(x, dtype): def softmax(x, dim=None): + """Compute the softmax of an XTensorVariable along a specified dimension.""" exp_x = exp(x) return exp_x / exp_x.sum(dim=dim) @@ -195,11 +562,11 @@ def make_node(self, x, y): return Apply(self, [x, y], [out]) -def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None): - """Matrix multiplication between two XTensorVariables. +def dot(x, y, dim: str | Sequence[str] | EllipsisType | None = None): + """Generalized dot product for XTensorVariables. - This operation performs matrix multiplication between two tensors, automatically - aligning and contracting dimensions. The behavior matches xarray's dot operation. + This operation performs multiplication followed by summation for shared dimensions + or simply summation for non-shared dimensions. Parameters ---------- @@ -207,21 +574,29 @@ def dot(x, y, dim: str | Iterable[str] | EllipsisType | None = None): First input tensor y : XTensorVariable Second input tensor - dim : str, Iterable[Hashable], EllipsisType, or None, optional + dim : str, Sequence[str], Ellipsis (...), or None, optional The dimensions to contract over. If None, will contract over all matching dimensions. If Ellipsis (...), will contract over all dimensions. Returns ------- XTensorVariable - The result of the matrix multiplication. + Examples -------- - >>> x = xtensor(dtype="float64", dims=("a", "b"), shape=(2, 3)) - >>> y = xtensor(dtype="float64", dims=("b", "c"), shape=(3, 4)) - >>> z = dot(x, y) # Result has dimensions ("a", "c") - >>> z = dot(x, y, dim=...) # Contract over all dimensions + + .. 
testcode:: + + from pytensor.xtensor import xtensor, dot + + x = xtensor("x", dims=("a", "b")) + y = xtensor("y", dims=("b", "c")) + + assert dot(x, y).dims == ("a", "c") # Contract over shared `b` dimension + assert dot(x, y, dim=("a", "b")).dims == ("c",) # Contract over 'a' and 'b' + assert dot(x, y, dim=...).dims == () # Contract over all dimensions + """ x = as_xtensor(x) y = as_xtensor(y) diff --git a/pytensor/xtensor/random.py b/pytensor/xtensor/random.py index 8f24ae24e1..98d9bb96df 100644 --- a/pytensor/xtensor/random.py +++ b/pytensor/xtensor/random.py @@ -10,7 +10,7 @@ from pytensor.xtensor.vectorization import XRV -def _as_xrv( +def as_xrv( core_op: RandomVariable, core_inps_dims_map: Sequence[Sequence[int]] | None = None, core_out_dims_map: Sequence[int] | None = None, @@ -52,7 +52,6 @@ def _as_xrv( max((entry + 1 for entry in core_out_dims_map), default=0), ) - @wraps(core_op) def xrv_constructor( *params, core_dims: Sequence[str] | str | None = None, @@ -93,38 +92,151 @@ def xrv_constructor( return xrv_constructor -bernoulli = _as_xrv(ptr.bernoulli) -beta = _as_xrv(ptr.beta) -betabinom = _as_xrv(ptr.betabinom) -binomial = _as_xrv(ptr.binomial) -categorical = _as_xrv(ptr.categorical) -cauchy = _as_xrv(ptr.cauchy) -dirichlet = _as_xrv(ptr.dirichlet) -exponential = _as_xrv(ptr.exponential) -gamma = _as_xrv(ptr._gamma) -gengamma = _as_xrv(ptr.gengamma) -geometric = _as_xrv(ptr.geometric) -gumbel = _as_xrv(ptr.gumbel) -halfcauchy = _as_xrv(ptr.halfcauchy) -halfnormal = _as_xrv(ptr.halfnormal) -hypergeometric = _as_xrv(ptr.hypergeometric) -integers = _as_xrv(ptr.integers) -invgamma = _as_xrv(ptr.invgamma) -laplace = _as_xrv(ptr.laplace) -logistic = _as_xrv(ptr.logistic) -lognormal = _as_xrv(ptr.lognormal) -multinomial = _as_xrv(ptr.multinomial) -nbinom = negative_binomial = _as_xrv(ptr.negative_binomial) -normal = _as_xrv(ptr.normal) -pareto = _as_xrv(ptr.pareto) -poisson = _as_xrv(ptr.poisson) -t = _as_xrv(ptr.t) -triangular = _as_xrv(ptr.triangular) 
-truncexpon = _as_xrv(ptr.truncexpon) -uniform = _as_xrv(ptr.uniform) -vonmises = _as_xrv(ptr.vonmises) -wald = _as_xrv(ptr.wald) -weibull = _as_xrv(ptr.weibull) +def _as_xrv(core_op: RandomVariable, name: str | None = None): + """A decorator to create a new XRV and document it in sphinx.""" + xrv_constructor = as_xrv(core_op, name=name) + + def decorator(func): + @wraps(as_xrv) + def wrapper(*args, **kwargs): + return xrv_constructor(*args, **kwargs) + + wrapper.__doc__ = f"XRV version of {core_op.name} for XTensorVariables" + + return wrapper + + return decorator + + +@_as_xrv(ptr.bernoulli) +def bernoulli(): ... + + +@_as_xrv(ptr.beta) +def beta(): ... + + +@_as_xrv(ptr.betabinom) +def betabinom(): ... + + +@_as_xrv(ptr.binomial) +def binomial(): ... + + +@_as_xrv(ptr.categorical) +def categorical(): ... + + +@_as_xrv(ptr.cauchy) +def cauchy(): ... + + +@_as_xrv(ptr.dirichlet) +def dirichlet(): ... + + +@_as_xrv(ptr.exponential) +def exponential(): ... + + +@_as_xrv(ptr._gamma) +def gamma(): ... + + +@_as_xrv(ptr.gengamma) +def gengamma(): ... + + +@_as_xrv(ptr.geometric) +def geometric(): ... + + +@_as_xrv(ptr.gumbel) +def gumbel(): ... + + +@_as_xrv(ptr.halfcauchy) +def halfcauchy(): ... + + +@_as_xrv(ptr.halfnormal) +def halfnormal(): ... + + +@_as_xrv(ptr.hypergeometric) +def hypergeometric(): ... + + +@_as_xrv(ptr.integers) +def integers(): ... + + +@_as_xrv(ptr.invgamma) +def invgamma(): ... + + +@_as_xrv(ptr.laplace) +def laplace(): ... + + +@_as_xrv(ptr.logistic) +def logistic(): ... + + +@_as_xrv(ptr.lognormal) +def lognormal(): ... + + +@_as_xrv(ptr.multinomial) +def multinomial(): ... + + +@_as_xrv(ptr.negative_binomial) +def negative_binomial(): ... + + +nbinom = negative_binomial + + +@_as_xrv(ptr.normal) +def normal(): ... + + +@_as_xrv(ptr.pareto) +def pareto(): ... + + +@_as_xrv(ptr.poisson) +def poisson(): ... + + +@_as_xrv(ptr.t) +def t(): ... + + +@_as_xrv(ptr.triangular) +def triangular(): ... 
+ + +@_as_xrv(ptr.truncexpon) +def truncexpon(): ... + + +@_as_xrv(ptr.uniform) +def uniform(): ... + + +@_as_xrv(ptr.vonmises) +def vonmises(): ... + + +@_as_xrv(ptr.wald) +def wald(): ... + + +@_as_xrv(ptr.weibull) +def weibull(): ... def multivariate_normal( @@ -136,6 +248,7 @@ def multivariate_normal( rng=None, method: Literal["cholesky", "svd", "eigh"] = "cholesky", ): + """Multivariate normal random variable.""" mean = as_xtensor(mean) if len(core_dims) != 2: raise ValueError( @@ -147,7 +260,7 @@ def multivariate_normal( if core_dims[0] not in mean.type.dims: core_dims = core_dims[::-1] - xop = _as_xrv(ptr.MvNormalRV(method=method)) + xop = as_xrv(ptr.MvNormalRV(method=method)) return xop(mean, cov, core_dims=core_dims, extra_dims=extra_dims, rng=rng) diff --git a/pytensor/xtensor/readme.md b/pytensor/xtensor/readme.md deleted file mode 100644 index b3511f56ad..0000000000 --- a/pytensor/xtensor/readme.md +++ /dev/null @@ -1,69 +0,0 @@ -# XTensor Module - -This module implements as abstraction layer on regular tensor operations, that behaves like Xarray. - -A new type `XTensorType`, generalizes the `TensorType` with the addition of a `dims` attribute, -that labels the dimensions of the tensor. - -Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects. - -The module implements several PyTensor operations `XOp`s, whose signature mimics that of xarray (and xarray_einstants) DataArray operations. -These operations, unlike most regular PyTensor operations, cannot be directly evaluated, but require a rewrite (lowering) into -a regular tensor graph that can itself be evaluated as usual. - -Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray. -If the existing XOps can be composed to produce the desired result, then we can use them directly. 
- -## Coordinates -For now, there's no analogous of xarray coordinates, so you won't be able to do coordinate operations like `.sel`. -The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor. -Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor. - -## Example - -```python -import pytensor.tensor as pt -import pytensor.xtensor as px - -a = pt.tensor("a", shape=(3,)) -b = pt.tensor("b", shape=(4,)) - -ax = px.as_xtensor(a, dims=["x"]) -bx = px.as_xtensor(b, dims=["y"]) - -zx = ax + bx -assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4)) - -z = zx.values -z.dprint() -# TensorFromXTensor [id A] -# └─ XElemwise{scalar_op=Add()} [id B] -# ├─ XTensorFromTensor{dims=('x',)} [id C] -# │ └─ a [id D] -# └─ XTensorFromTensor{dims=('y',)} [id E] -# └─ b [id F] -``` - -Once we compile the graph, no `XOp`s are left. - -```python -import pytensor - -with pytensor.config.change_flags(optimizer_verbose=True): - fn = pytensor.function([a, b], z) - -# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0) -# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None -# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None -# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0) - -fn.dprint() -# Add [id A] 2 -# ├─ ExpandDims{axis=1} [id B] 1 -# │ └─ a [id C] -# └─ ExpandDims{axis=0} [id D] 0 -# └─ b [id E] -``` - - - diff --git 
a/pytensor/xtensor/shape.py b/pytensor/xtensor/shape.py index a4b1491f71..3e2116e56b 100644 --- a/pytensor/xtensor/shape.py +++ b/pytensor/xtensor/shape.py @@ -303,6 +303,35 @@ def make_node(self, *inputs): def concat(xtensors, dim: str): + """Concatenate a sequence of XTensorVariables along a specified dimension. + + Parameters + ---------- + xtensors : Sequence of XTensorVariable + The tensors to concatenate. + dim : str + The dimension along which to concatenate the tensors. + + Returns + ------- + XTensorVariable + + + Example + ------- + + .. testcode:: + + from pytensor.xtensor import as_xtensor, xtensor, concat + + x = xtensor("x", shape=(2, 3), dims=("a", "b")) + zero = as_xtensor([0], dims=("a",)) + + out = concat([zero, x, zero], dim="a") + assert out.type.dims == ("a", "b") + assert out.type.shape == (4, 3) + + """ return Concat(dim=dim)(*xtensors) diff --git a/pytensor/xtensor/type.py b/pytensor/xtensor/type.py index 0c8ca0914e..1e16912eaa 100644 --- a/pytensor/xtensor/type.py +++ b/pytensor/xtensor/type.py @@ -201,6 +201,24 @@ def xtensor( shape: Sequence[int | None] | None = None, dtype: str | np.dtype = "floatX", ): + """Create an XTensorVariable. + + Parameters + ---------- + name : str or None, optional + The name of the variable + dims : Sequence[str] + The names of the dimensions of the tensor + shape : Sequence[int | None] or None, optional + The shape of the tensor. If None, defaults to a shape with None for each dimension. + dtype : str or np.dtype, optional + The data type of the tensor. Defaults to 'floatX' (config.floatX). + + Returns + ------- + XTensorVariable + A new XTensorVariable with the specified name, dims, shape, and dtype. 
+ """ return XTensorType(dtype=dtype, dims=dims, shape=shape)(name=name) @@ -208,6 +226,8 @@ def xtensor( class XTensorVariable(Variable[_XTensorTypeType, OptionalApplyType]): + """Variable of XTensorType.""" + # These can't work because Python requires native output types def __bool__(self): raise TypeError( @@ -406,7 +426,7 @@ def rename(self, new_name_or_name_dict=None, **names): def copy(self, name: str | None = None): out = px.math.identity(self) - out.name = name # type: ignore + out.name = name return out def astype(self, dtype): @@ -751,6 +771,8 @@ class XTensorConstantSignature(TensorConstantSignature): class XTensorConstant(XTensorVariable, Constant[_XTensorTypeType]): + """Constant of XTensorType.""" + def __init__(self, type: _XTensorTypeType, data, name=None): data_shape = np.shape(data) @@ -776,6 +798,8 @@ def signature(self): def xtensor_constant(x, name=None, dims: None | Sequence[str] = None): + """Convert a constant value to an XTensorConstant.""" + x_dims: tuple[str, ...] if XARRAY_AVAILABLE and isinstance(x, xr.DataArray): xarray_dims = x.dims @@ -819,7 +843,20 @@ def as_symbolic_xarray(x, **kwargs): return xtensor_constant(x, **kwargs) -def as_xtensor(x, name=None, dims: Sequence[str] | None = None): +def as_xtensor(x, dims: Sequence[str] | None = None, *, name: str | None = None): + """Convert a variable or data to an XTensorVariable. + + Parameters + ---------- + x : Variable or data + dims : Sequence[str] or None, optional + If dims are provided, TensorVariable (or data) will be converted to an XTensorVariable with those dims. + XTensorVariables will be returned as is, if the dims match. Otherwise, a ValueError is raised. + If dims are not provided, and the data is not a scalar, an XTensorVariable or xarray.DataArray, an error is raised. + name : str or None, optional + Name of the resulting XTensorVariable. + """ + if isinstance(x, Apply): if len(x.outputs) != 1: raise ValueError(