-
Notifications
You must be signed in to change notification settings - Fork 137
MLX backend POC #1365
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account-related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
MLX backend POC #1365
Changes from 1 commit
d25f214
edacc0e
052fdc2
a9ecad0
e690bff
ad29c17
ba29b37
8716870
9bf7edf
96ba116
e116fa1
5d5f754
b8cee3f
d057453
ae202e6
f1941fe
7c8eae7
67a74fb
fb5eb52
516b595
67bb8da
82bb964
877d79f
fafedd6
60acb8d
242aba7
6cb47fc
bc98e09
4d5b34b
5abd32d
12daeac
ac93949
a19cbc8
5c97bc8
2fc81bc
e6437cc
c3a3e1a
e7cf10e
880dd5c
1e6addd
aabbb78
0812c55
294c271
9d3eca8
9f31ab1
37440ff
4e4923f
6b27dc4
e308f83
3744a18
b41cab0
323fa9d
9766975
5ffc5ef
597f84e
fb8fd2f
6a2b774
602f0ed
8a2aea9
845561c
6ab7428
b6292f1
f2d9d1b
5c759b9
97b2e31
dd83e0f
662b4f2
bcf7f8d
e706171
929630b
8f2982d
02ed254
06ccf91
03a2094
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,18 +1,19 @@ | ||
from collections.abc import Callable, Iterable | ||
from functools import partial | ||
|
||
import mlx.core as mx | ||
import numpy as np | ||
import pytest | ||
|
||
import pytensor | ||
from pytensor import tensor as pt | ||
from pytensor.compile.function import function | ||
from pytensor.compile.mode import MLX, Mode | ||
from pytensor.graph import RewriteDatabaseQuery | ||
from pytensor.graph.basic import Variable | ||
from pytensor.link.mlx import MLXLinker | ||
from pytensor.link.mlx.dispatch.core import mlx_funcify_ScalarFromTensor | ||
|
||
|
||
mx = pytest.importorskip("mlx.core") | ||
|
||
optimizer = RewriteDatabaseQuery(include=["mlx"], exclude=MLX._optimizer.exclude) | ||
mlx_mode = Mode(linker=MLXLinker(), optimizer=optimizer) | ||
py_mode = Mode(linker="py", optimizer=None) | ||
|
@@ -78,3 +79,52 @@ def compare_mlx_and_py( | |
assert_fn(mlx_res, py_res) | ||
|
||
return pytensor_mlx_fn, mlx_res | ||
|
||
|
||
def test_scalar_from_tensor_with_scalars(): | ||
"""Test ScalarFromTensor works with both MLX arrays and Python/NumPy scalars. | ||
|
||
This addresses the AttributeError that occurred when Python integers were | ||
passed to ScalarFromTensor instead of MLX arrays. | ||
""" | ||
scalar_from_tensor_func = mlx_funcify_ScalarFromTensor(None) | ||
|
||
# Test with MLX array | ||
mlx_array = mx.array([42]) | ||
result = scalar_from_tensor_func(mlx_array) | ||
assert result == 42 | ||
|
||
# Test with Python int (this used to fail) | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Why are Python ints being passed? That suggests a bug elsewhere. There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Indeed, not a bug, but I didn't replace a few things — done! |
||
python_int = 42 | ||
result = scalar_from_tensor_func(python_int) | ||
assert result == 42 | ||
|
||
# Test with Python float | ||
python_float = 3.14 | ||
result = scalar_from_tensor_func(python_float) | ||
assert abs(result - 3.14) < 1e-6 | ||
|
||
# Test with NumPy scalar | ||
numpy_scalar = np.int32(123) | ||
result = scalar_from_tensor_func(numpy_scalar) | ||
assert result == 123 | ||
|
||
# Test with NumPy float scalar | ||
numpy_float = np.float32(2.71) | ||
result = scalar_from_tensor_func(numpy_float) | ||
assert abs(result - 2.71) < 1e-6 | ||
|
||
|
||
def test_scalar_from_tensor_pytensor_integration(): | ||
"""Test ScalarFromTensor in a PyTensor graph context.""" | ||
# Create a 0-d tensor (scalar tensor) | ||
x = pt.as_tensor_variable(42) | ||
|
||
# Apply ScalarFromTensor | ||
scalar_result = pt.scalar_from_tensor(x) | ||
|
||
# Create function and test | ||
f = pytensor.function([], scalar_result, mode="MLX") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This won't test MLX at all because it's a constant function and you are running full optimization, so it will just be constant folded. Instead make There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Totally right, missed that 🙌🏻 |
||
result = f() | ||
|
||
assert result == 42 |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,3 +1,4 @@ | ||
import numpy as np | ||
import pytest | ||
|
||
import pytensor.tensor as pt | ||
|
@@ -11,3 +12,39 @@ def test_input(op) -> None: | |
x_test = mx.array([1.0, 2.0, 3.0]) | ||
|
||
compare_mlx_and_py([x], out, [x_test]) | ||
|
||
|
||
def test_new_elemwise_operations() -> None: | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Bad test name — they won't be new by the time this PR is merged ;) There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. Done! |
||
"""Test new elemwise operations (IntDiv, IsNan, Erfc, Erfcx, Softplus) in elemwise context""" | ||
x = pt.vector("x") | ||
y = pt.vector("y") | ||
|
||
# Test int_div in an elemwise expression | ||
out_int_div = pt.int_div(x, y) + 1 | ||
x_test = mx.array([10.0, 15.0, 20.0]) | ||
y_test = mx.array([3.0, 4.0, 6.0]) | ||
compare_mlx_and_py([x, y], out_int_div, [x_test, y_test]) | ||
|
||
# Test isnan in an elemwise expression | ||
z = pt.vector("z") | ||
out_isnan = pt.isnan(z).astype("float32") * 10 | ||
z_test = mx.array([1.0, np.nan, 3.0]) | ||
compare_mlx_and_py([z], out_isnan, [z_test]) | ||
|
||
# Test erfc in an elemwise expression | ||
w = pt.vector("w") | ||
out_erfc = pt.erfc(w) * 2.0 | ||
w_test = mx.array([0.0, 0.5, 1.0]) | ||
compare_mlx_and_py([w], out_erfc, [w_test]) | ||
|
||
# Test erfcx in an elemwise expression | ||
v = pt.vector("v") | ||
out_erfcx = pt.erfcx(v) + 0.1 | ||
v_test = mx.array([0.0, 1.0, 2.0]) | ||
compare_mlx_and_py([v], out_erfcx, [v_test]) | ||
|
||
# Test softplus in an elemwise expression | ||
u = pt.vector("u") | ||
out_softplus = pt.softplus(u) - 0.5 | ||
u_test = mx.array([0.0, 1.0, -1.0]) | ||
compare_mlx_and_py([u], out_softplus, [u_test]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems convoluted. Does MLX have something like
x.item()
? There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I need to check this again. One minute.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It has but we have issues with the
pytensor-mlx
way to compile, and this makes it generic enough. There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Optimized to have both options — hope you like that more!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Done!