Add xtensor broadcast #1489

Open
wants to merge 22 commits into base: labeled_tensors
Commits (22)
4fb9071
Avoid no-op DimShuffle
ricardoV94 Jun 20, 2025
5b39df6
Use DimShuffle instead of Reshape in `ix_`
ricardoV94 May 22, 2025
7a7db6f
Extract ViewOp functionality into a base TypeCastOp
ricardoV94 Jun 20, 2025
024136e
Implement basic labeled tensor functionality
ricardoV94 Aug 2, 2023
162b50a
Implement stack for XTensorVariables
ricardoV94 Jun 6, 2025
9676b4e
Implement Elemwise and Blockwise operations for XTensorVariables
ricardoV94 May 26, 2025
1a0226c
Implement cast for XTensorVariables
ricardoV94 Jun 6, 2025
b1b5fde
Implement reduction operations for XTensorVariables
ricardoV94 May 25, 2025
3e4f7ae
Implement concat for XTensorVariables
ricardoV94 May 26, 2025
62d410b
Implement transpose for XTensorVariables
AllenDowney May 28, 2025
bd29b1b
Implement unstack for XTensorVariables
OriolAbril May 22, 2025
e6da6c5
Implement index for XTensorVariables
ricardoV94 May 21, 2025
e2a87db
Implement index update for XTensorVariables
ricardoV94 Jun 2, 2025
fc5f668
Implement diff for XTensorVariables
ricardoV94 May 26, 2025
4235d63
Implement squeeze for XTensorVariables
AllenDowney Jun 6, 2025
7a9db22
Implement expand_dims for XTensorVariables (#1449)
AllenDowney Jun 13, 2025
cc28cb0
Implement dot for XTensorVariables (#1475)
AllenDowney Jun 19, 2025
0be448e
Implement XTensorVariable version of RandomVariables
ricardoV94 Jun 20, 2025
2a91f58
Add implementation of broadcast for xtensor
AllenDowney Jun 20, 2025
4f02b39
Add xtensor broadcast
AllenDowney Jun 20, 2025
ea04c9e
Handling symbolic dims
AllenDowney Jun 21, 2025
9a255b6
Adding broadcast_like
AllenDowney Jun 21, 2025
11 changes: 11 additions & 0 deletions .github/workflows/test.yml
@@ -82,6 +82,7 @@ jobs:
install-numba: [0]
install-jax: [0]
install-torch: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
@@ -115,6 +116,7 @@
install-numba: 0
install-jax: 0
install-torch: 0
install-xarray: 0
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.10"
@@ -150,6 +152,13 @@
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
- install-xarray: 1
os: "ubuntu-latest"
python-version: "3.13"
numpy-version: ">=2.0"
fast-compile: 0
float32: 0
part: "tests/xtensor"
- os: macos-15
python-version: "3.13"
numpy-version: ">=2.0"
@@ -196,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
pip install pytest-sphinx

pip install -e ./
@@ -212,6 +222,7 @@ jobs:
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_XARRAY: ${{ matrix.install-xarray }}
OS: ${{ matrix.os}}

- name: Run tests
23 changes: 11 additions & 12 deletions pytensor/compile/ops.py
@@ -33,11 +33,8 @@ def register_view_op_c_code(type, code, version=()):
ViewOp.c_code_and_version[type] = (code, version)


class ViewOp(COp):
"""
Returns an inplace view of the input. Used internally by PyTensor.

"""
class TypeCastingOp(COp):
"""Op that performs a graph-level type cast operation, but has no effect computation-wise (identity function)."""

view_map = {0: [0]}
# Mapping from Type to C code (and version) to use.
@@ -47,13 +44,8 @@ class ViewOp(COp):
__props__: tuple = ()
_f16_ok: bool = True

def make_node(self, x):
return Apply(self, [x], [x.type()])

def perform(self, node, inp, out):
(x,) = inp
(z,) = out
z[0] = x
def perform(self, node, inputs, outputs_storage):
outputs_storage[0][0] = inputs[0]

def __str__(self):
return f"{self.__class__.__name__}"
@@ -97,6 +89,13 @@ def grad(self, args, g_outs):
return g_outs


class ViewOp(TypeCastingOp):
"""Returns an inplace view of the input. Used internally by PyTensor."""

def make_node(self, x):
return Apply(self, [x], [x.type()])


view_op = ViewOp()


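For illustration, here is a minimal hypothetical subclass (not part of this diff) showing how the new TypeCastingOp base is meant to be used: make_node decides the graph-level output type, while the inherited perform and view_map keep the runtime behavior a zero-copy identity. The DropStaticShape name and the example variables are assumptions made for this sketch.

import pytensor.tensor as pt
from pytensor.compile.ops import TypeCastingOp
from pytensor.graph.basic import Apply
from pytensor.tensor.type import TensorType


class DropStaticShape(TypeCastingOp):
    """Hypothetical example: forget a tensor's static shape without copying data."""

    def make_node(self, x):
        # Output keeps dtype/ndim but drops static shape info;
        # the inherited perform() simply forwards the input.
        out_type = TensorType(dtype=x.type.dtype, shape=(None,) * x.type.ndim)
        return Apply(self, [x], [out_type()])


x = pt.vector("x", shape=(3,))
y = DropStaticShape()(x)  # identity at runtime, shape=(None,) in the graph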
10 changes: 5 additions & 5 deletions pytensor/link/jax/dispatch/basic.py
@@ -8,7 +8,7 @@

from pytensor.compile import JAX
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.ops import DeepCopyOp, ViewOp
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
@@ -111,12 +111,12 @@ def deepcopyop(x):
return deepcopyop


@jax_funcify.register(ViewOp)
def jax_funcify_ViewOp(op, **kwargs):
def viewop(x):
@jax_funcify.register(TypeCastingOp)
def jax_funcify_TypeCastingOp(op, **kwargs):
def type_cast(x):
return x

return viewop
return type_cast


@jax_funcify.register(OpFromGraph)
10 changes: 5 additions & 5 deletions pytensor/link/numba/dispatch/scalar.py
@@ -2,7 +2,7 @@

import numpy as np

from pytensor.compile.ops import ViewOp
from pytensor.compile.ops import TypeCastingOp
from pytensor.graph.basic import Variable
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
@@ -198,14 +198,14 @@ def cast(x):


@numba_basic.numba_njit
def viewop(x):
def identity(x):
return x


@numba_funcify.register(Identity)
@numba_funcify.register(ViewOp)
def numba_funcify_ViewOp(op, **kwargs):
return numba_basic.global_numba_func(viewop)
@numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs):
return numba_basic.global_numba_func(identity)


@numba_basic.numba_njit
10 changes: 9 additions & 1 deletion pytensor/link/pytorch/dispatch/basic.py
@@ -9,7 +9,7 @@
from pytensor.compile import PYTORCH
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
@@ -71,6 +71,14 @@ def pytorch_funcify_FunctionGraph(
)


@pytorch_funcify.register(TypeCastingOp)
def pytorch_funcify_CastingOp(op, node, **kwargs):
def type_cast(x):
return x

return type_cast


@pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs):
error = op.exc_type
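Across the JAX, Numba and PyTorch registrations above, a TypeCastingOp lowers to a plain identity. A quick illustration using the existing ViewOp and the default backend (the registrations above are what give the other linkers the same behavior):

import pytensor
import pytensor.tensor as pt
from pytensor.compile.ops import view_op

x = pt.vector("x")
f = pytensor.function([x], view_op(x))
print(f([1.0, 2.0]))  # [1. 2.] -- the op just forwards its input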
2 changes: 1 addition & 1 deletion pytensor/tensor/basic.py
@@ -4551,7 +4551,7 @@ def ix_(*args):
new = as_tensor(new)
if new.ndim != 1:
raise ValueError("Cross index must be 1 dimensional")
new = new.reshape((1,) * k + (new.size,) + (1,) * (nd - k - 1))
new = new.dimshuffle(*(("x",) * k), 0, *(("x",) * (nd - k - 1)))
out.append(new)
return tuple(out)

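The `ix_` change replaces a Reshape with a DimShuffle that inserts broadcastable dimensions. A small sanity check of the intended equivalence (values are illustrative; this mirrors the nd=2, k=0 case):

import numpy as np
import pytensor.tensor as pt

idx = pt.as_tensor(np.array([0, 2, 3]))
old_style = idx.reshape((idx.size, 1))   # previous approach
new_style = idx.dimshuffle(0, "x")       # new approach: same (3, 1) layout
np.testing.assert_array_equal(old_style.eval(), new_style.eval())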
18 changes: 0 additions & 18 deletions pytensor/tensor/extra_ops.py
@@ -473,24 +473,6 @@ def cumprod(x, axis=None):
return CumOp(axis=axis, mode="mul")(x)


class CumsumOp(Op):
__props__ = ("axis",)

def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "add"
return obj


class CumprodOp(Op):
__props__ = ("axis",)

def __new__(typ, *args, **kwargs):
obj = object.__new__(CumOp, *args, **kwargs)
obj.mode = "mul"
return obj


def diff(x, n=1, axis=-1):
"""Calculate the `n`-th order discrete difference along the given `axis`.

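The removed CumsumOp/CumprodOp classes were thin wrappers whose __new__ returned a CumOp, so the public helpers are unaffected. For example (illustrative):

import numpy as np
import pytensor.tensor as pt

x = pt.vector("x")
print(pt.cumsum(x).eval({x: np.array([1.0, 2.0, 3.0])}))   # [1. 3. 6.]
print(pt.cumprod(x).eval({x: np.array([1.0, 2.0, 3.0])}))  # [1. 2. 6.]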
4 changes: 2 additions & 2 deletions pytensor/tensor/random/basic.py
@@ -1625,8 +1625,7 @@ def rng_fn_scipy(cls, rng, n, p, size):
return stats.nbinom.rvs(n, p, size=size, random_state=rng)


nbinom = NegBinomialRV()
negative_binomial = NegBinomialRV()
nbinom = negative_binomial = NegBinomialRV()


class BetaBinomialRV(ScipyRandomVariable):
@@ -1808,6 +1807,7 @@ def rng_fn(cls, rng, n, p, size):

multinomial = MultinomialRV()


vsearchsorted = np.vectorize(np.searchsorted, otypes=[int], signature="(n),()->()")


19 changes: 18 additions & 1 deletion pytensor/tensor/random/op.py
@@ -1,3 +1,4 @@
import abc
import warnings
from collections.abc import Sequence
from copy import deepcopy
@@ -32,7 +33,20 @@
from pytensor.tensor.variable import TensorVariable


class RandomVariable(Op):
class RNGConsumerOp(Op):
"""Baseclass for Ops that consume RNGs."""

@abc.abstractmethod
def update(self, node: Apply) -> dict[Variable, Variable]:
"""Symbolic update expression for input RNG variables.

Returns a dictionary with the symbolic expressions required for correct updating
of RNG variables in repeated function evaluations.
"""
pass


class RandomVariable(RNGConsumerOp):
"""An `Op` that produces a sample from a random variable.

This is essentially `RandomFunction`, except that it removes the
@@ -123,6 +137,9 @@ def __init__(
if self.inplace:
self.destroy_map = {0: [0]}

def update(self, node: Apply) -> dict[Variable, Variable]:
return {node.inputs[0]: node.outputs[0]}

def _supp_shape_from_params(self, dist_params, param_shapes=None):
"""Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.

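A sketch of the new interface in use (variable names are mine): every RandomVariable implements update by mapping its input RNG to the node's next-state RNG output, which is how compiled functions know to advance the generator between calls.

import pytensor.tensor.random as ptr
from pytensor.tensor.random.type import random_generator_type

rng = random_generator_type("rng")
x = ptr.normal(0, 1, rng=rng)
node = x.owner
next_rng = node.outputs[0]
assert node.op.update(node) == {rng: next_rng}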
4 changes: 1 addition & 3 deletions pytensor/tensor/rewriting/basic.py
@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
return [new_var]


@register_infer_shape
@node_rewriter([Assert])
def local_remove_all_assert(fgraph, node):
r"""A rewrite that removes all `Assert`\s from a graph.
@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
See the :ref:`unsafe` section.

"""
if not isinstance(node.op, Assert):
return

return [node.inputs[0]]


29 changes: 29 additions & 0 deletions pytensor/tensor/utils.py
@@ -9,6 +9,7 @@
import pytensor
from pytensor.graph import FunctionGraph, Variable
from pytensor.npy_2_compat import normalize_axis_tuple
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.utils import hash_from_code


@@ -256,3 +257,31 @@ def faster_ndindex(shape: Sequence[int]):
https://github.com/numpy/numpy/issues/28921
"""
return product(*(range(s) for s in shape))


def get_static_shape_from_size_variables(
size_vars: Sequence[Variable],
) -> tuple[int | None, ...]:
"""Get static shape from size variables.

Parameters
----------
size_vars : Sequence[Variable]
A sequence of variables representing the size of each dimension.
Returns
-------
tuple[int | None, ...]
A tuple containing the static lengths of each dimension, or None if
the length is not statically known.
"""
from pytensor.tensor.basic import get_scalar_constant_value

static_lengths: list[None | int] = [None] * len(size_vars)
for i, length in enumerate(size_vars):
try:
static_length = get_scalar_constant_value(length)
except NotScalarConstantError:
pass
else:
static_lengths[i] = int(static_length)
return tuple(static_lengths)
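A short usage example of the new helper (illustrative values): constant size entries yield static lengths, symbolic ones stay None.

import pytensor.tensor as pt
from pytensor.tensor.utils import get_static_shape_from_size_variables

n = pt.iscalar("n")
print(get_static_shape_from_size_variables([pt.constant(3), n]))  # (3, None)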
3 changes: 3 additions & 0 deletions pytensor/tensor/variable.py
@@ -349,6 +349,9 @@ def dimshuffle(self, *pattern):
if (len(pattern) == 1) and (isinstance(pattern[0], list | tuple | np.ndarray)):
pattern = pattern[0]
ds_op = pt.elemwise.DimShuffle(input_ndim=self.type.ndim, new_order=pattern)
if ds_op.new_order == tuple(range(self.type.ndim)):
# No-op
return self
return ds_op(self)

def flatten(self, ndim=1):
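A quick illustration of the shortcut added above: an identity pattern now returns the variable itself instead of wrapping it in a no-op DimShuffle.

import pytensor.tensor as pt

x = pt.matrix("x")
assert x.dimshuffle(0, 1) is x       # identity pattern: no new node
assert x.dimshuffle(1, 0) is not x   # a real transpose still builds a DimShuffle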
14 changes: 14 additions & 0 deletions pytensor/xtensor/__init__.py
@@ -0,0 +1,14 @@
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor import linalg
from pytensor.xtensor.math import dot
from pytensor.xtensor.shape import broadcast, concat
from pytensor.xtensor.type import (
as_xtensor,
xtensor,
xtensor_constant,
)


warnings.warn("xtensor module is experimental and full of bugs")