jax-metal: argmax fails with unsigned index type #21577

@jonatanklosko

Description

import jax
import jax.numpy as jnp

def f(x):
  return jax.lax.argmax(x, 0, index_dtype=jnp.uint32)

x = jnp.array([2, 3, 1])

# Print lowered HLO
print(jax.jit(f).lower(x).as_text())
print(jax.jit(f)(x))
HLO
module @jit_f attributes {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32} {
  func.func public @main(%arg0: tensor<3xi32> {mhlo.layout_mode = "default"}) -> (tensor<ui32> {jax.result_info = "", mhlo.layout_mode = "default"}) {
    %0 = call @argmax(%arg0) : (tensor<3xi32>) -> tensor<ui32>
    return %0 : tensor<ui32>
  }
  func.func private @argmax(%arg0: tensor<3xi32>) -> tensor<ui32> {
    %0 = stablehlo.iota dim = 0 : tensor<3xui32>
    %c = stablehlo.constant dense<-2147483648> : tensor<i32>
    %c_0 = stablehlo.constant dense<0> : tensor<ui32>
    %1:2 = stablehlo.reduce(%arg0 init: %c), (%0 init: %c_0) across dimensions = [0] : (tensor<3xi32>, tensor<3xui32>, tensor<i32>, tensor<ui32>) -> (tensor<i32>, tensor<ui32>)
     reducer(%arg1: tensor<i32>, %arg3: tensor<i32>) (%arg2: tensor<ui32>, %arg4: tensor<ui32>)  {
      %2 = stablehlo.compare  GT, %arg1, %arg3,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %3 = stablehlo.compare  NE, %arg1, %arg1,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %4 = stablehlo.or %2, %3 : tensor<i1>
      %5 = stablehlo.compare  EQ, %arg1, %arg3,  SIGNED : (tensor<i32>, tensor<i32>) -> tensor<i1>
      %6 = stablehlo.compare  LT, %arg2, %arg4,  UNSIGNED : (tensor<ui32>, tensor<ui32>) -> tensor<i1>
      %7 = stablehlo.and %5, %6 : tensor<i1>
      %8 = stablehlo.or %4, %7 : tensor<i1>
      %9 = stablehlo.select %4, %arg1, %arg3 : tensor<i1>, tensor<i32>
      %10 = stablehlo.select %8, %arg2, %arg4 : tensor<i1>, tensor<ui32>
      stablehlo.return %9, %10 : tensor<i32>, tensor<ui32>
    }
    return %1#1 : tensor<ui32>
  }
}
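The variadic reduce above is the IR that jax-metal matches and rewrites into its argmax. Its reducer keeps the larger value, where a NaN always wins because of the NE self-comparison. On ties it keeps the smaller index. The plain-Python sketch below mirrors that comparator. It folds the elements sequentially from the i32 init values (-2147483648, index 0). XLA does not guarantee a sequential reduction order, so the fold only illustrates the comparator. The helper names argmax_reducer and argmax_like_hlo are hypothetical.

```python
from functools import reduce

# Hypothetical helper mirroring the reducer region of the HLO above.
# Each argument is a (value, index) pair, like (%arg1, %arg2) and (%arg3, %arg4).
def argmax_reducer(acc, cur):
    (a_val, a_idx), (b_val, b_idx) = acc, cur
    # %4: a wins if strictly greater, or if a is NaN (a != a).
    a_wins = a_val > b_val or a_val != a_val
    # %8: take a's index if a wins, or on a tie when a's index is smaller.
    take_a_idx = a_wins or (a_val == b_val and a_idx < b_idx)
    # %9 / %10: the value and the index are selected independently.
    return (a_val if a_wins else b_val, a_idx if take_a_idx else b_idx)

def argmax_like_hlo(xs):
    # Init values match the i32 HLO: %c = -2147483648, %c_0 = 0.
    init = (-2**31, 0)
    # %0 = iota supplies the candidate indices 0..n-1.
    pairs = zip(xs, range(len(xs)))
    return reduce(argmax_reducer, pairs, init)[1]

print(argmax_like_hlo([2, 3, 1]))  # 1
```

The reducer produces one value result and one index result, and the index result carries whatever element type the caller requested (ui32 here). A rewrite that replaces the whole reduce therefore has to produce that same type.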

Fails with:

Traceback (most recent call last):
  File "/Users/jonatanklosko/git/nx/exla/tmp/jax_dbg_mps.py", line 156, in <module>
    print(jax.jit(f)(x))
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<si32>') doesn't match function result type ('tensor<ui32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%6) : (tensor<si32>) -> ()

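Currently, jax-metal pattern-matches this reduce IR and rewrites it into a dedicated argmax op. Here, that op returns a signed index (tensor<si32>) instead of the requested tensor<ui32>. Perhaps it should insert a convert when the rewritten op's output type differs from the expected result type.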
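Until the rewrite handles this, one possible user-side workaround is sketched below. It has not been verified on jax-metal. The idea is to request a signed index type, which the error suggests the rewritten argmax already produces, and then convert explicitly with jax.lax.convert_element_type. Two things are assumptions: that the Metal plugin accepts int32 indices in this pattern, and that it lowers the trailing convert correctly. The function name f_workaround is hypothetical.

```python
import jax
import jax.numpy as jnp

def f_workaround(x):
    # Request a signed index type (int32), then convert to uint32 explicitly.
    # Assumption: the jax-metal argmax rewrite handles int32 indices.
    idx = jax.lax.argmax(x, 0, index_dtype=jnp.int32)
    return jax.lax.convert_element_type(idx, jnp.uint32)

x = jnp.array([2, 3, 1])
result = jax.jit(f_workaround)(x)
print(result, result.dtype)
```

The explicit convert makes the lowered module end in a stablehlo.convert from i32 to ui32. This is roughly what the plugin-side fix proposed above would insert automatically.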

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.26
numpy:  1.26.4
python: 3.10.8 (main, Nov 16 2022, 12:45:33) [Clang 14.0.0 (clang-1400.0.29.202)]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='chonker', release='23.5.0', version='Darwin Kernel Version 23.5.0: Wed May  1 20:12:58 PDT 2024; root:xnu-10063.121.3~5/RELEASE_ARM64_T6000', machine='arm64')

jax-metal 0.0.7
