-
Notifications
You must be signed in to change notification settings - Fork 3.1k
Closed
Labels
Description
import jax
import jax.numpy as jnp
def f(x):
    """Return the position of the largest element of ``x`` along axis 0.

    The index is requested as ``uint32`` rather than the default integer
    type. That choice is what exposes the jax-metal type mismatch reported
    in this issue.
    """
    return jax.lax.argmax(x, axis=0, index_dtype=jnp.uint32)
x = jnp.array([2, 3, 1])

# Dump the StableHLO module that jax.jit produces for f.
jitted_f = jax.jit(f)
print(jitted_f.lower(x).as_text())

# Execute the compiled function. On jax-metal 0.0.7 this step raised the
# XlaRuntimeError reported in the traceback further down.
print(jitted_f(x))
Lowered HLO (output of the first `print`):
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<ui32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
%0 = call @argmax(%arg0) : (tensor<3xi32>) -> tensor<ui32>
return %0 : tensor<ui32>
}
func.func private @argmax(%arg0: tensor<3xi32>) -> tensor<ui32> {
%0 = stablehlo.iota dim = 0 : tensor<3xui32>
%c = stablehlo.constant dense<-2147483648> : tensor<i32>
%c_0 = stablehlo.constant dense<0> : tensor<ui32>
%1:2 = stablehlo.reduce(%arg0 init: %c), (%0 init: %c_0) across dimensions = [0] : (tensor<3xi32>, tensor<3xui32>, tensor<i32>, tensor<ui32>) -> (tensor<i32>, tensor<ui32>)
reducer(%arg1: tensor<i32>, %arg3: tensor<i32>) (%arg2: tensor<ui32>, %arg4: tensor<ui32>) {
%2 = stablehlo.compare GT, %arg1, %arg3, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%3 = stablehlo.compare NE, %arg1, %arg1, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%4 = stablehlo.or %2, %3 : tensor<i1>
%5 = stablehlo.compare EQ, %arg1, %arg3, SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
%6 = stablehlo.compare LT, %arg2, %arg4, UNSIGNED : (tensor<ui32>, tensor<ui32>) -> tensor<i1>
%7 = stablehlo.and %5, %6 : tensor<i1>
%8 = stablehlo.or %4, %7 : tensor<i1>
%9 = stablehlo.select %4, %arg1, %arg3 : tensor<i1>, tensor<i32>
%10 = stablehlo.select %8, %arg2, %arg4 : tensor<i1>, tensor<ui32>
stablehlo.return %9, %10 : tensor<i32>, tensor<ui32>
}
return %1#1 : tensor<ui32>
}
}
Executing the jitted function on the Metal backend fails with:
Traceback (most recent call last):
File "/Users/jonatanklosko/git/nx/exla/tmp/jax_dbg_mps.py", line 156, in <module>
print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<si32>') doesn't match function result type ('tensor<ui32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%6) : (tensor<si32>) -> ()
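Currently, jax-metal pattern-matches this reduce IR and rewrites it into a dedicated argmax op. That op produces `si32` indices, while the function is declared to return `ui32`. jax-metal should probably insert a convert whenever the requested index type differs from the type the rewritten op produces.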
System info (python version, jaxlib version, accelerator, etc.)
jax: 0.4.26
jaxlib: 0.4.26
numpy: 1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May 1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')
jax-metal 0.0.7