JIT constant folding #21300

@inversecrime

Description

Hi,
I was hoping that someone could help me with this.

Sometimes, when using constants in jitted functions, I get warnings like this one:

2024-05-19 20:16:26.694439: E external/xla/xla/service/slow_operation_alarm.cc:65] Constant folding an instruction is taking > 1s:

  %reduce.8 = f64[200000,10]{1,0} reduce(f64[200000,10,10]{2,1,0} %broadcast.2, f64[] %constant.3), dimensions={2}, to_apply=%region_0.4, metadata={op_name="jit(f)/jit(main)/reduce_sum[axes=(2,)]" source_file="..." source_line=13}

This isn't necessarily a bug; constant-folding is inherently a trade-off between compilation time and speed at runtime. XLA has some guards that attempt to keep constant folding from taking too long, but fundamentally you'll always be able to come up with an input program that takes a long time.

If you'd like to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
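The dump flag mentioned in the message can also be set from inside a script. A minimal sketch, assuming the variable has to be in place before jax is imported (the usual advice for XLA_FLAGS); `dump_dir` is a hypothetical location chosen for illustration:

```python
import os
import tempfile

# Hypothetical dump location; the warning text itself suggests /tmp/foo.
dump_dir = os.path.join(tempfile.gettempdir(), "xla_dump")

# Append to any flags already set, and do it before importing jax so the
# flag is in place when XLA reads the environment.
existing = os.environ.get("XLA_FLAGS", "")
os.environ["XLA_FLAGS"] = f"{existing} --xla_dump_to={dump_dir}".strip()

import jax
import jax.numpy as jnp

# Functions compiled after this point should leave their HLO dumps in dump_dir.
result = jax.jit(lambda x: x + 1)(jnp.zeros(3))
```
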

These warnings appear seemingly at random, for example with the following code:

from functools import wraps
import jax
import jax.numpy as jnp
import jax.core

jax.config.update("jax_enable_x64", True)
jax.config.update("jax_platforms", "cpu")

v = jnp.zeros((200000, 10, 10))


def f():
    return jax.vmap(jax.vmap(jnp.sum))(v)


jax.jit(f)()

This code produces "constant folding" warnings on both Windows and Linux. Whether they appear probably also depends on the OS version, CPU type, and so on.

When playing around with array shapes and the number of nested vmaps, these messages appear or disappear without any pattern that is clear to me. For example, this is fast:

v = jnp.zeros((1000000, 10, 10))
def f():
    return jax.vmap(jnp.sum)(v)
jax.jit(f)()

While this is slow and produces the warning:

v = jnp.zeros((1000000, 10, 2))
def f():
    return jax.vmap(jnp.sum)(v)
jax.jit(f)()

Constant folding only happens when compiling with jax.jit; making jaxprs is not affected.
Since jaxprs capture closed-over constants explicitly, it is possible to compile them while treating those constants as variables.
The following function demonstrates this:

def other_jit(f):
    @wraps(f)
    def wrapper(*args):
        jaxpr = jax.make_jaxpr(f)(*args)
        return jax.jit(lambda c, *a: jax.core.eval_jaxpr(jaxpr.jaxpr, c, *a))(jaxpr.consts, *args)
    return wrapper

Now, using other_jit(f)() instead of jax.jit(f)() prevents the issue.
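For comparison, the same effect can usually be had without going through jaxprs, by passing the array as an explicit argument so that it is traced rather than baked in as a constant. A minimal sketch, assuming the constant can be threaded through the call; `f_closure` and `f_arg` are illustrative names, and the shape is shrunk from the one above so the example runs quickly:

```python
import jax
import jax.numpy as jnp

# Small shape so the sketch runs quickly; the example above used (200000, 10, 10).
v = jnp.zeros((2000, 10, 10))


def f_closure():
    # v is closed over, so jax.jit embeds it as a compile-time constant.
    return jax.vmap(jax.vmap(jnp.sum))(v)


def f_arg(x):
    # x is a traced argument, so the compiler sees a parameter, not a constant.
    return jax.vmap(jax.vmap(jnp.sum))(x)


out_closure = jax.jit(f_closure)()
out_arg = jax.jit(f_arg)(v)
```

Both calls compute the same values; only the way v reaches the compiler differs.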

I was wondering if this is intended behavior.
Wouldn't it be a better solution in most cases to always treat constants as variables while compiling, to prevent constant folding from slowing down compilations?

In real-world scenarios, using (a generalized version of) the other_jit function I presented here can significantly reduce compilation times from a few minutes to just seconds.
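The generalized version is not shown here, so the following is only a guess at what one could look like: a sketch that extends other_jit to pytree inputs and outputs. The name `other_jit_pytree` and the helper `run` are hypothetical; it relies on the `return_shape=True` option of jax.make_jaxpr to recover the output structure.

```python
from functools import wraps

import jax
import jax.core
import jax.numpy as jnp
from jax import tree_util


def other_jit_pytree(f):
    """Hypothetical generalization of other_jit to pytree inputs and outputs."""

    @wraps(f)
    def wrapper(*args):
        # Trace once; return_shape=True also reports the output pytree.
        closed, out_shape = jax.make_jaxpr(f, return_shape=True)(*args)
        out_tree = tree_util.tree_structure(out_shape)

        def run(consts, *flat_args):
            # Captured constants arrive as ordinary jit arguments here.
            return jax.core.eval_jaxpr(closed.jaxpr, consts, *flat_args)

        flat_out = jax.jit(run)(closed.consts, *tree_util.tree_leaves(args))
        # eval_jaxpr returns a flat list, so rebuild the original structure.
        return tree_util.tree_unflatten(out_tree, flat_out)

    return wrapper


v_small = jnp.ones((4, 3))


def g(params):
    # v_small is closed over; params is a dict (a pytree) argument.
    return {"row_sums": jnp.sum(v_small, axis=1) * params["scale"]}


out = other_jit_pytree(g)({"scale": jnp.float32(2.0)})
```

As far as I can tell, a wrapper written this way re-traces and builds a fresh jitted function on every call, so a practical version would also need to cache the compiled function per input signature.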

What's your opinion on this? I would appreciate any help or suggestions.

System info (python version, jaxlib version, accelerator, etc.)

cpu
jax 0.4.28
jaxlib 0.4.28

Metadata

Labels: bug (Something isn't working)