MLX backend POC #1365

Open · wants to merge 74 commits into base: main

74 commits
d25f214
mlx poc
williambdean Apr 11, 2025
edacc0e
add test for dot
williambdean Apr 11, 2025
052fdc2
restore pytorch
williambdean Apr 11, 2025
a9ecad0
wrap in mx.array
williambdean Apr 11, 2025
e690bff
modify the pytorch jit
williambdean Apr 11, 2025
ad29c17
move file
williambdean Apr 11, 2025
ba29b37
dont wrap
williambdean Apr 11, 2025
8716870
attempt to fix github action
williambdean Apr 11, 2025
9bf7edf
change the rtol
williambdean Apr 11, 2025
96ba116
add init file
williambdean Apr 11, 2025
e116fa1
skip if not installed
williambdean Apr 11, 2025
5d5f754
remove torch related code / comments
williambdean Apr 11, 2025
b8cee3f
simplify the fgraph_convert
williambdean Apr 12, 2025
d057453
assert type
williambdean Apr 12, 2025
ae202e6
simplify the internal
williambdean Apr 18, 2025
f1941fe
remove the language
williambdean Apr 18, 2025
7c8eae7
Adding operations in pytensor
cetagostini Apr 18, 2025
67a74fb
add extension
williambdean Apr 18, 2025
fb5eb52
make compare function
williambdean Apr 18, 2025
516b595
rename function
williambdean Apr 18, 2025
67bb8da
correct the function name
williambdean Apr 18, 2025
82bb964
tests for elemwise
williambdean Apr 18, 2025
877d79f
Changes
cetagostini Apr 18, 2025
fafedd6
Toma tu tomate William
cetagostini Apr 18, 2025
60acb8d
Pushing changes with the core shit.
cetagostini Apr 18, 2025
242aba7
add more tests
williambdean Apr 18, 2025
6cb47fc
additional tests
williambdean Apr 18, 2025
bc98e09
test for switch with mlx
williambdean Apr 18, 2025
4d5b34b
Pushing code
cetagostini Apr 18, 2025
5abd32d
Changes
cetagostini Apr 18, 2025
12daeac
A lot of new code
cetagostini Apr 18, 2025
ac93949
almost there baby william
cetagostini Apr 18, 2025
a19cbc8
Another push small
cetagostini Apr 18, 2025
5c97bc8
fix for all
williambdean Apr 18, 2025
2fc81bc
fix for carlos
williambdean Apr 18, 2025
e6437cc
just return the compiled func
williambdean Apr 19, 2025
c3a3e1a
A change for willy may!
cetagostini Apr 19, 2025
e7cf10e
FINALLY BABY LETS PARTY! (IF YOU ARE READING THIS MAKE MORE PRs)
cetagostini Apr 19, 2025
880dd5c
refactor to use getattr
williambdean Apr 19, 2025
1e6addd
bring argmax test
williambdean Apr 19, 2025
aabbb78
use deepcopy
williambdean Apr 19, 2025
0812c55
move some tests
williambdean Apr 19, 2025
294c271
THE SUPER BLOCKWISEE YA YA YA YA JUUUUU
cetagostini Apr 19, 2025
9d3eca8
Merge branch 'mlx-poc' of https://github.com/williambdean/pytensor in…
cetagostini Apr 19, 2025
9f31ab1
Guys, I'm getting sad. We need help yisus!!!!!
cetagostini Apr 19, 2025
37440ff
WILLIAM YOU NEED TO GO ANOTHER MILE! GO ON MY MATEEEEEEE, GO PHILLIES!
cetagostini Apr 19, 2025
4e4923f
RETURN, WHAT A SHAME! Sad times are coming.
cetagostini Apr 19, 2025
6b27dc4
AI COULD BE COOL? OR WE ARE JUST FUCKING AROUND?
cetagostini Apr 19, 2025
e308f83
AI RULES BABY MY MATE
cetagostini Apr 19, 2025
3744a18
test conv1d case
williambdean Apr 19, 2025
b41cab0
I'm going for pizzas, it was an incredible day!
cetagostini Apr 19, 2025
323fa9d
Merge branch 'mlx-poc' of https://github.com/williambdean/pytensor in…
cetagostini Apr 19, 2025
9766975
SUUUUUUUUU!!!!!! LIFE IS GOING WELL. MLX FOR MEDIA MIX MODELS BAY
cetagostini Apr 19, 2025
5ffc5ef
pre-commit
cetagostini Apr 19, 2025
597f84e
Almost working
cetagostini Apr 19, 2025
fb8fd2f
Last PR sampling working
cetagostini Apr 23, 2025
6a2b774
Requested changes by Ricardo
cetagostini Jun 2, 2025
602f0ed
Pre commit changes
cetagostini Jun 2, 2025
8a2aea9
More changes from Ricardo
cetagostini Jun 8, 2025
845561c
Pre Commit RUN
cetagostini Jun 8, 2025
6ab7428
Adding more operations for complex model
cetagostini Jun 8, 2025
b6292f1
Working with simple model
cetagostini Jun 9, 2025
f2d9d1b
Change bad name
cetagostini Jun 9, 2025
5c759b9
Correcting test by Ricardo
cetagostini Jun 9, 2025
97b2e31
Changing synth test
cetagostini Jun 9, 2025
dd83e0f
Optimizing reshape
cetagostini Jun 9, 2025
662b4f2
Comment
cetagostini Jun 9, 2025
bcf7f8d
Small changes and adding small benchmark
cetagostini Jun 9, 2025
e706171
Changes with Ricardo
cetagostini Jun 10, 2025
929630b
improving benchmark
cetagostini Jun 10, 2025
8f2982d
pre commit
cetagostini Jun 10, 2025
02ed254
benchs
cetagostini Jun 10, 2025
06ccf91
Merge branch 'main' into mlx-poc
cetagostini Jul 11, 2025
03a2094
Changes on the branch
cetagostini Jul 11, 2025
2 changes: 1 addition & 1 deletion pytensor/link/mlx/dispatch/core.py
@@ -177,7 +177,7 @@ def tensor_from_scalar(x):
@mlx_funcify.register(ScalarFromTensor)
def mlx_funcify_ScalarFromTensor(op, **kwargs):
def scalar_from_tensor(x):
return x.reshape(-1)[0]
return mx.array(x).reshape(-1)[0]
Reviewer (Member): This seems convoluted. Does MLX have something like x.item()?

Reply: I need to check this again. One minute.

Reply: It does, but we have issues with the way PyTensor compiles to MLX, and this makes it generic enough.

Reply: Optimized to support both options; hope you like that more!

Reply: Done!


return scalar_from_tensor

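The fix above can be sketched outside MLX. In this sketch `np.asarray` stands in for `mx.array` (an assumption for illustration; the real dispatch runs on MLX arrays): wrapping first makes the helper accept Python and NumPy scalars as well as arrays.

```python
import numpy as np


def scalar_from_tensor(x):
    # Wrap before reshaping so plain Python/NumPy scalars, which have no
    # .reshape method, take the same path as arrays.
    return np.asarray(x).reshape(-1)[0]


print(scalar_from_tensor(np.array([42])))  # array input
print(scalar_from_tensor(42))              # plain Python int input
```

Without the wrap, `42.reshape(-1)` would raise an AttributeError, which is exactly the failure mode the test below exercises.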
54 changes: 52 additions & 2 deletions pytensor/link/mlx/dispatch/math.py
@@ -18,7 +18,9 @@
Cast,
Cos,
Exp,
IntDiv,
Invert,
IsNan,
Log,
Log1p,
Mul,
@@ -34,7 +36,7 @@
Switch,
TrueDiv,
)
from pytensor.scalar.math import Sigmoid
from pytensor.scalar.math import Erfc, Erfcx, Sigmoid, Softplus
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.math import Dot

@@ -113,6 +115,14 @@ def true_div(x, y):
return true_div


@mlx_funcify_Elemwise_scalar_op.register(IntDiv)
def _(scalar_op):
def int_div(x, y):
return mx.floor_divide(x, y)

return int_div


@mlx_funcify_Elemwise_scalar_op.register(Pow)
def _(scalar_op):
def pow(x, y):
@@ -309,11 +319,51 @@ def sigmoid(x):
@mlx_funcify_Elemwise_scalar_op.register(Invert)
def _(scalar_op):
def invert(x):
return ~x
return mx.bitwise_invert(x)

return invert


@mlx_funcify_Elemwise_scalar_op.register(IsNan)
def _(scalar_op):
def isnan(x):
return mx.isnan(x)

return isnan


@mlx_funcify_Elemwise_scalar_op.register(Erfc)
def _(scalar_op):
def erfc(x):
return 1.0 - mx.erf(x)

return erfc


@mlx_funcify_Elemwise_scalar_op.register(Erfcx)
def _(scalar_op):
def erfcx(x):
return mx.exp(x * x) * (1.0 - mx.erf(x))

return erfcx


@mlx_funcify_Elemwise_scalar_op.register(Softplus)
def _(scalar_op):
def softplus(x):
# Numerically stable implementation of log(1 + exp(x))
# Following the same logic as the original PyTensor implementation
return mx.where(
x < -37.0,
mx.exp(x),
mx.where(
x < 18.0, mx.log1p(mx.exp(x)), mx.where(x < 33.3, x + mx.exp(-x), x)
),
)

return softplus


@mlx_funcify.register(Elemwise)
def mlx_funcify_Elemwise(op, node, **kwargs):
# Dispatch to the appropriate scalar op handler
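The softplus branches above can be checked independently of MLX; here is a numpy sketch (numpy stands in for `mx` in this illustration) mirroring the same thresholds:

```python
import numpy as np


def softplus(x):
    """Numerically stable log(1 + exp(x)), mirroring the branch thresholds above."""
    x = np.asarray(x, dtype=np.float64)
    return np.where(
        x < -37.0,
        np.exp(x),                # log1p(exp(x)) ~= exp(x) for very negative x
        np.where(
            x < 18.0,
            np.log1p(np.exp(x)),  # direct evaluation is safe in this range
            # for large x, log(1 + e^x) ~= x + e^-x, and eventually just x
            np.where(x < 33.3, x + np.exp(-x), x),
        ),
    )


print(softplus(np.array([-40.0, 0.0, 50.0])))
```

For x = 50 the naive `log1p(exp(x))` would overflow in float32, while the branched form returns x exactly; in the mid-range all branches agree with the direct formula to floating-point precision.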
12 changes: 11 additions & 1 deletion pytensor/link/mlx/dispatch/shape.py
@@ -1,5 +1,7 @@
import mlx.core as mx

from pytensor.link.mlx.dispatch.basic import mlx_funcify
from pytensor.tensor.shape import Shape, Shape_i, SpecifyShape
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape


@mlx_funcify.register(Shape)
@@ -30,3 +32,11 @@ def shape_i(x):
return x.shape[op.i]

return shape_i


@mlx_funcify.register(Reshape)
def mlx_funcify_Reshape(op, **kwargs):
def reshape(x, shp):
return mx.reshape(x, shp)

return reshape
56 changes: 53 additions & 3 deletions tests/link/mlx/test_basic.py
@@ -1,18 +1,19 @@
from collections.abc import Callable, Iterable
from functools import partial

import mlx.core as mx
import numpy as np
import pytest

import pytensor
from pytensor import tensor as pt
from pytensor.compile.function import function
from pytensor.compile.mode import MLX, Mode
from pytensor.graph import RewriteDatabaseQuery
from pytensor.graph.basic import Variable
from pytensor.link.mlx import MLXLinker
from pytensor.link.mlx.dispatch.core import mlx_funcify_ScalarFromTensor


mx = pytest.importorskip("mlx.core")

optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude)
mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer)
py_mode = Mode(linker="py", optimizer=None)
@@ -78,3 +79,52 @@ def compare_mlx_and_py(
assert_fn(mlx_res, py_res)

return pytensor_mlx_fn, mlx_res


def test_scalar_from_tensor_with_scalars():
"""Test ScalarFromTensor works with both MLX arrays and Python/NumPy scalars.

This addresses the AttributeError that occurred when Python integers were
passed to ScalarFromTensor instead of MLX arrays.
"""
scalar_from_tensor_func = mlx_funcify_ScalarFromTensor(None)

# Test with MLX array
mlx_array = mx.array([42])
result = scalar_from_tensor_func(mlx_array)
assert result == 42

# Test with Python int (this used to fail)

Reviewer (Member): Why are python ints being passed? That suggests a bug elsewhere.

Reply: Indeed, not a bug; I just hadn't replaced a few things. Done!

python_int = 42
result = scalar_from_tensor_func(python_int)
assert result == 42

# Test with Python float
python_float = 3.14
result = scalar_from_tensor_func(python_float)
assert abs(result - 3.14) < 1e-6

# Test with NumPy scalar
numpy_scalar = np.int32(123)
result = scalar_from_tensor_func(numpy_scalar)
assert result == 123

# Test with NumPy float scalar
numpy_float = np.float32(2.71)
result = scalar_from_tensor_func(numpy_float)
assert abs(result - 2.71) < 1e-6


def test_scalar_from_tensor_pytensor_integration():
"""Test ScalarFromTensor in a PyTensor graph context."""
# Create a 0-d tensor (scalar tensor)
x = pt.as_tensor_variable(42)

# Apply ScalarFromTensor
scalar_result = pt.scalar_from_tensor(x)

# Create function and test
f = pytensor.function([], scalar_result, mode="MLX")

Reviewer (Member): This won't test MLX at all, because it's a constant function and you are running full optimization, so it will just be constant folded. Instead, make x a symbolic variable, like x = pytensor.scalar.int64(x), that is an input to the function.

cetagostini (Jun 9, 2025): Totally right, missed that 🙌🏻

result = f()

assert result == 42
37 changes: 37 additions & 0 deletions tests/link/mlx/test_elemwise.py
@@ -1,3 +1,4 @@
import numpy as np
import pytest

import pytensor.tensor as pt
@@ -11,3 +12,39 @@ def test_input(op) -> None:
x_test = mx.array([1.0, 2.0, 3.0])

compare_mlx_and_py([x], out, [x_test])


def test_new_elemwise_operations() -> None:
Reviewer (Member): Bad test name, they won't be new by the time this PR is merged ;)

Reply: Done!


"""Test new elemwise operations (IntDiv, IsNan, Erfc, Erfcx, Softplus) in elemwise context"""
x = pt.vector("x")
y = pt.vector("y")

# Test int_div in an elemwise expression
out_int_div = pt.int_div(x, y) + 1
x_test = mx.array([10.0, 15.0, 20.0])
y_test = mx.array([3.0, 4.0, 6.0])
compare_mlx_and_py([x, y], out_int_div, [x_test, y_test])

# Test isnan in an elemwise expression
z = pt.vector("z")
out_isnan = pt.isnan(z).astype("float32") * 10
z_test = mx.array([1.0, np.nan, 3.0])
compare_mlx_and_py([z], out_isnan, [z_test])

# Test erfc in an elemwise expression
w = pt.vector("w")
out_erfc = pt.erfc(w) * 2.0
w_test = mx.array([0.0, 0.5, 1.0])
compare_mlx_and_py([w], out_erfc, [w_test])

# Test erfcx in an elemwise expression
v = pt.vector("v")
out_erfcx = pt.erfcx(v) + 0.1
v_test = mx.array([0.0, 1.0, 2.0])
compare_mlx_and_py([v], out_erfcx, [v_test])

# Test softplus in an elemwise expression
u = pt.vector("u")
out_softplus = pt.softplus(u) - 0.5
u_test = mx.array([0.0, 1.0, -1.0])
compare_mlx_and_py([u], out_softplus, [u_test])
114 changes: 114 additions & 0 deletions tests/link/mlx/test_math.py
@@ -79,6 +79,7 @@ def test_input(op) -> None:
pytest.param(pt.eq, id="eq"),
pytest.param(pt.neq, id="neq"),
pytest.param(pt.true_div, id="true_div"),
pytest.param(pt.int_div, id="int_div"),
],
)
def test_elemwise_two_inputs(op) -> None:
@@ -90,6 +91,119 @@ def test_elemwise_two_inputs(op) -> None:
compare_mlx_and_py([x, y], out, [x_test, y_test])


def test_int_div_specific() -> None:
"""Test integer division with specific test cases"""
x = pt.vector("x")
y = pt.vector("y")
out = pt.int_div(x, y)

# Test with integers that demonstrate floor division behavior
x_test = mx.array([7.0, 8.0, 9.0, -7.0, -8.0])
y_test = mx.array([3.0, 3.0, 3.0, 3.0, 3.0])

compare_mlx_and_py([x, y], out, [x_test, y_test])
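
The floor-division behavior this test exercises rounds toward negative infinity, not toward zero. A quick check with numpy (standing in for `mx.floor_divide`, assuming identical semantics):

```python
import numpy as np

# int_div lowers to floor division: -7 // 3 == -3 (floor), not -2 (C-style truncation).
x = np.array([7.0, 8.0, 9.0, -7.0, -8.0])
y = np.array([3.0, 3.0, 3.0, 3.0, 3.0])

result = np.floor_divide(x, y)
print(result)  # [ 2.  2.  3. -3. -3.]
```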


def test_isnan() -> None:
"""Test IsNan operation with various inputs including NaN values"""
x = pt.vector("x")
out = pt.isnan(x)

# Test with mix of normal values, NaN, and infinity
x_test = mx.array([1.0, np.nan, 3.0, np.inf, -np.nan, 0.0, -np.inf])

compare_mlx_and_py([x], out, [x_test])


def test_isnan_edge_cases() -> None:
"""Test IsNan with edge cases"""
x = pt.scalar("x")
out = pt.isnan(x)

# Test individual cases
test_cases = [0.0, np.nan, np.inf, -np.inf, 1e-10, 1e10]

for test_val in test_cases:
x_test = test_val
compare_mlx_and_py([x], out, [x_test])


def test_erfc() -> None:
"""Test complementary error function"""
x = pt.vector("x")
out = pt.erfc(x)

# Test with various values including negative, positive, and zero
x_test = mx.array([0.0, 0.5, 1.0, -0.5, -1.0, 2.0, -2.0, 0.1])

compare_mlx_and_py([x], out, [x_test])


def test_erfc_extreme_values() -> None:
"""Test erfc with extreme values"""
x = pt.vector("x")
out = pt.erfc(x)

# Test with larger values where erfc approaches 0 or 2
x_test = mx.array([-3.0, -2.5, 2.5, 3.0])

# Use relaxed tolerance for extreme values due to numerical precision differences
from functools import partial

relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-3, atol=1e-6)

compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert)


def test_erfcx() -> None:
"""Test scaled complementary error function"""
x = pt.vector("x")
out = pt.erfcx(x)

# Test with positive values where erfcx is most numerically stable
x_test = mx.array([0.0, 0.5, 1.0, 1.5, 2.0, 2.5])

compare_mlx_and_py([x], out, [x_test])


def test_erfcx_small_values() -> None:
"""Test erfcx with small values"""
x = pt.vector("x")
out = pt.erfcx(x)

# Test with small values
x_test = mx.array([0.001, 0.01, 0.1, 0.2])

compare_mlx_and_py([x], out, [x_test])
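
The dispatch computes erfcx via the identity erfcx(x) = exp(x²)·erfc(x), which is accurate for the small-to-moderate x used in these tests but overflows in the exp(x²) factor once x² exceeds the float64 exponent range (roughly x > 26), which is why the test values stay small. A stdlib sanity check of the identity (using `math.erfc` from Python's standard library; the asymptotic series here is for illustration):

```python
import math


def erfcx_naive(x):
    # Same identity as the MLX dispatch: exp(x*x) * erfc(x).
    return math.exp(x * x) * math.erfc(x)


# For moderate x the naive product matches the asymptotic expansion
# erfcx(x) ~ 1/(x*sqrt(pi)) * (1 - 1/(2*x**2) + ...)
approx = 1.0 / (5.0 * math.sqrt(math.pi)) * (1.0 - 1.0 / (2.0 * 5.0**2))
print(erfcx_naive(5.0), approx)
```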


def test_softplus() -> None:
"""Test softplus (log(1 + exp(x))) function"""
x = pt.vector("x")
out = pt.softplus(x)

# Test with normal range values
x_test = mx.array([0.0, 1.0, 2.0, -1.0, -2.0, 10.0])

compare_mlx_and_py([x], out, [x_test])


def test_softplus_extreme_values() -> None:
"""Test softplus with extreme values to verify numerical stability"""
x = pt.vector("x")
out = pt.softplus(x)

# Test with extreme values where different branches of the implementation are used
x_test = mx.array([-40.0, -50.0, 20.0, 30.0, 35.0, 50.0])

# Use relaxed tolerance for extreme values due to numerical precision differences
from functools import partial

relaxed_assert = partial(np.testing.assert_allclose, rtol=1e-4, atol=1e-8)

compare_mlx_and_py([x], out, [x_test], assert_fn=relaxed_assert)


@pytest.mark.xfail(reason="Argmax not implemented yet")
def test_mlx_max_and_argmax():
# Test that a single output of a multi-output `Op` can be used as input to