Skip to content

Commit 5a50406

Browse files
committed
(Experimental) Integrate Metal PjRt plugin
1 parent b2fdb9a commit 5a50406

File tree

12 files changed

+170
-8
lines changed

12 files changed

+170
-8
lines changed

exla/c_src/exla/exla.cc

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,10 @@ ERL_NIF_TERM mlir_compile(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
144144
build_options.set_use_spmd_partitioning(use_spmd);
145145

146146
bool compile_portable_executable = false;
147-
if (device_id >= 0) {
147+
148+
bool is_mps = (*client)->client()->platform_name() == "METAL";
149+
150+
if (device_id >= 0 && !is_mps) {
148151
compile_portable_executable = true;
149152
build_options.set_device_ordinal(device_id);
150153
}
@@ -877,6 +880,16 @@ ERL_NIF_TERM get_tpu_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[])
877880
return exla::nif::ok(env, exla::nif::make<exla::ExlaClient*>(env, client));
878881
}
879882

883+
ERL_NIF_TERM get_mps_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
884+
if (argc != 0) {
885+
return exla::nif::error(env, "Bad argument count.");
886+
}
887+
888+
EXLA_ASSIGN_OR_RETURN_NIF(exla::ExlaClient * client, exla::GetMpsClient(), env);
889+
890+
return exla::nif::ok(env, exla::nif::make<exla::ExlaClient*>(env, client));
891+
}
892+
880893
ERL_NIF_TERM get_c_api_client(ErlNifEnv* env, int argc, const ERL_NIF_TERM argv[]) {
881894
if (argc != 1) {
882895
return exla::nif::error(env, "Bad argument count.");
@@ -1065,6 +1078,7 @@ static ErlNifFunc exla_funcs[] = {
10651078
{"get_host_client", 0, get_host_client},
10661079
{"get_gpu_client", 2, get_gpu_client},
10671080
{"get_tpu_client", 0, get_tpu_client},
1081+
{"get_mps_client", 0, get_mps_client},
10681082
{"get_c_api_client", 1, get_c_api_client},
10691083
{"load_pjrt_plugin", 2, load_pjrt_plugin},
10701084
{"get_device_count", 1, get_device_count},

exla/c_src/exla/exla_client.cc

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -495,6 +495,30 @@ xla::StatusOr<ExlaClient*> GetTpuClient() {
495495
return new ExlaClient(std::move(client));
496496
}
497497

498+
xla::StatusOr<ExlaClient*> GetMpsClient() {
499+
// The plugin may be compiled for a different version of PjRt C API
500+
// than present in our XLA compilation. By default pjrt::LoadPjrtPlugin
501+
// raises if the version does not match. By setting this environment
502+
// variable, we relax this check to allow different versions, as long
503+
// as they satisfy compatibility constraints.
504+
//
505+
// See https://github.com/openxla/xla/blob/4e8e23f16bc925b6f27817de098a8e1e81296bb5/xla/pjrt/pjrt_api.cc
506+
setenv("ENABLE_PJRT_COMPATIBILITY", "1", 1);
507+
508+
EXLA_ASSIGN_OR_RETURN(const PJRT_Api* pjrt_api, pjrt::LoadPjrtPlugin("METAL", "pjrt_plugin_metal.dylib"));
509+
510+
xla::Status status = pjrt::InitializePjrtPlugin("METAL");
511+
512+
if (!status.ok()) {
513+
return status;
514+
}
515+
516+
EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
517+
xla::GetCApiClient("METAL"));
518+
519+
return new ExlaClient(std::move(client));
520+
}
521+
498522
xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type) {
499523
EXLA_ASSIGN_OR_RETURN(std::unique_ptr<xla::PjRtClient> client,
500524
xla::GetCApiClient(device_type));

exla/c_src/exla/exla_client.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ xla::StatusOr<ExlaClient*> GetGpuClient(double memory_fraction,
110110

111111
xla::StatusOr<ExlaClient*> GetTpuClient();
112112

113+
xla::StatusOr<ExlaClient*> GetMpsClient();
114+
113115
xla::StatusOr<ExlaClient*> GetCApiClient(std::string device_type);
114116
} // namespace exla
115117

exla/lib/exla/client.ex

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,9 @@ defmodule EXLA.Client do
159159
:tpu ->
160160
EXLA.NIF.get_tpu_client()
161161

162+
:mps ->
163+
EXLA.NIF.get_mps_client()
164+
162165
_ ->
163166
raise ArgumentError, "unknown EXLA platform: #{inspect(platform)}"
164167
end

exla/lib/exla/defn.ex

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -712,6 +712,10 @@ defmodule EXLA.Defn do
712712
) do
713713
precision = state.precision
714714

715+
# Ensure both have the same type
716+
left = to_type(left, ans.type)
717+
right = to_type(right, ans.type)
718+
715719
Value.dot_general(
716720
left,
717721
right,

exla/lib/exla/mlir/value.ex

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -672,9 +672,15 @@ defmodule EXLA.MLIR.Value do
672672
typespecs
673673
) do
674674
result_types = typespecs_to_mlir_types(typespecs)
675-
regions = [on_true, on_false]
676-
pred = convert(pred, Typespec.tensor({:pred, 8}, {}))
677-
op(func, "stablehlo.if", [pred], result_types, regions: regions)
675+
676+
# TODO Jax does not support stablehlo.if, they use stablhelo.case instead.
677+
# It most likely makes sense for use to do the same. That said, note that
678+
# stablehlo.case is implemented for Metal, but does not lower reliably.
679+
# Reported in https://github.com/google/jax/issues/21601
680+
681+
regions = [on_false, on_true]
682+
pred = convert(pred, Typespec.tensor({:s, 32}, {}))
683+
op(func, "stablehlo.case", [pred], result_types, regions: regions)
678684
end
679685

680686
def infeed(%Value{function: func} = token, typespecs) do

exla/lib/exla/nif.ex

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,8 @@ defmodule EXLA.NIF do
6767

6868
def get_tpu_client(), do: :erlang.nif_error(:undef)
6969

70+
def get_mps_client(), do: :erlang.nif_error(:undef)
71+
7072
def get_supported_platforms, do: :erlang.nif_error(:undef)
7173

7274
def get_device_count(_client),

exla/mix.exs

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,9 +52,10 @@ defmodule EXLA.MixProject do
5252
cuda: [platform: :cuda],
5353
rocm: [platform: :rocm],
5454
tpu: [platform: :tpu],
55+
mps: [platform: :mps],
5556
host: [platform: :host]
5657
],
57-
preferred_clients: [:cuda, :rocm, :tpu, :host]
58+
preferred_clients: [:cuda, :rocm, :tpu, :mps, :host]
5859
]
5960
]
6061
end
@@ -129,11 +130,31 @@ defmodule EXLA.MixProject do
129130
:ok -> File.write!(xla_snapshot_path, xla_archive_path)
130131
{:error, term} -> Mix.raise("failed to extract xla archive, reason: #{inspect(term)}")
131132
end
133+
134+
# TODO should be packed into the XLA archive
135+
download_metal_plugin!(xla_extension_path)
132136
end
133137

134138
{:ok, []}
135139
end
136140

141+
defp download_metal_plugin!(xla_extension_path) do
142+
plugin_path = Path.join(xla_extension_path, "lib/pjrt_plugin_metal.dylib")
143+
144+
wheel_url =
145+
"https://files.pythonhosted.org/packages/09/dc/6d8fbfc29d902251cf333414cf7dcfaf4b252a9920c881354584ed36270d/jax_metal-0.1.1-py3-none-macosx_13_0_arm64.whl"
146+
147+
wheel_path = Path.join(xla_extension_path, "jax_metal.whl")
148+
149+
{_, 0} = System.shell("wget --output-document=#{wheel_path} #{wheel_url}")
150+
{_, 0} = System.shell("unzip #{wheel_path} -d #{xla_extension_path}")
151+
152+
wheel_plugin_path =
153+
Path.join(xla_extension_path, "jax_plugins/metal_plugin/pjrt_plugin_metal_14.dylib")
154+
155+
File.cp!(wheel_plugin_path, plugin_path)
156+
end
157+
137158
defp cached_make(args) do
138159
force_rebuild_env_var = System.get_env("EXLA_FORCE_REBUILD", "")
139160

exla/test/exla/backend_test.exs

Lines changed: 72 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,75 @@ defmodule EXLA.BackendTest do
2727
@skip_mac_arm []
2828
end
2929

30+
if EXLA.Client.default_name() == :mps do
31+
@skip_mps [
32+
# Missing support for "stablehlo.reduce_window".
33+
# Reported in https://github.com/google/jax/issues/21387
34+
window_max: 3,
35+
window_min: 3,
36+
window_sum: 3,
37+
window_product: 3,
38+
window_reduce: 5,
39+
window_scatter_min: 5,
40+
window_scatter_max: 5,
41+
window_mean: 3,
42+
# (edge case) Argmax/argmin return wrong value in case of NaN.
43+
# Reported in https://github.com/google/jax/issues/21821
44+
argmin: 2,
45+
argmax: 2,
46+
# Missing support for general "stablehlo.reduce". Some cases work
47+
# becuase they are special-cased.
48+
# Reported in https://github.com/google/jax/issues/21384
49+
reduce: 4,
50+
# Missing support for "stablehlo.popcnt", "stablehlo.count_leading_zeros",
51+
# "stablehlo.cbrt".
52+
# Reported in https://github.com/google/jax/issues/21389
53+
count_leading_zeros: 1,
54+
population_count: 1,
55+
cbrt: 1,
56+
# Matrix multiplication for integers is not supported
57+
dot: 2,
58+
dot: 4,
59+
dot: 6,
60+
covariance: 3,
61+
# (edge case) Put slice with overflowing slice, different behaviour.
62+
# Reported in https://github.com/google/jax/issues/21392
63+
put_slice: 3,
64+
# (edge case) Slice with overflowing index, different behaviour.
65+
# Reported in https://github.com/google/jax/issues/21393
66+
slice: 4,
67+
# (edge case) Top-k wrong behaviour with NaNs.
68+
# Reported in https://github.com/google/jax/issues/21397
69+
top_k: 2,
70+
# Missing support for complex numbers.
71+
# Tracked in https://github.com/google/jax/issues/16416
72+
complex: 2,
73+
conjugate: 1,
74+
conv: 3,
75+
fft: 2,
76+
fft2: 2,
77+
ifft: 2,
78+
ifft2: 2,
79+
imag: 1,
80+
is_infinity: 1,
81+
is_nan: 1,
82+
phase: 1,
83+
real: 1,
84+
sigil_MAT: 2,
85+
# Missing support for float-64.
86+
# Tracked in https://github.com/google/jax/issues/20938
87+
iota: 2,
88+
as_type: 2,
89+
atan2: 2,
90+
# Missing support for u2/s2
91+
bit_size: 1
92+
]
93+
else
94+
@skip_mps []
95+
end
96+
3097
doctest Nx,
31-
except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm
98+
except: [:moduledoc] ++ @excluded_doctests ++ @skip_mac_arm ++ @skip_mps
3299

33100
test "Nx.to_binary/1" do
34101
t = Nx.tensor([1, 2, 3, 4], backend: EXLA.Backend)
@@ -199,6 +266,8 @@ defmodule EXLA.BackendTest do
199266
end
200267

201268
describe "quantized types" do
269+
# TODO mising support for s2
270+
@tag :skip
202271
test "s2" do
203272
tensor = Nx.s2(-1)
204273
assert <<-1::2-signed-native>> = Nx.to_binary(tensor)
@@ -237,6 +306,8 @@ defmodule EXLA.BackendTest do
237306
assert 28 = Nx.bit_size(tensor)
238307
end
239308

309+
# TODO mising support for u2
310+
@tag :skip
240311
test "u2" do
241312
tensor = Nx.u2(1)
242313
assert <<1::2-native>> = Nx.to_binary(tensor)

nx/lib/nx.ex

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7941,6 +7941,16 @@ defmodule Nx do
79417941
end
79427942
end
79437943

7944+
# TODO remove this, or make it an optinal callback
7945+
# (Metal does not support stablehlo.logistic yet)
7946+
def sigmoid(x) do
7947+
x
7948+
|> Nx.negate()
7949+
|> Nx.exp()
7950+
|> Nx.add(1)
7951+
|> then(&Nx.divide(1, &1))
7952+
end
7953+
79447954
## Unary ops
79457955
@disallow_complex_type_unary_ops [:erf, :erfc, :erf_inv]
79467956

0 commit comments

Comments
 (0)