
This is a follow-up to a previous question about the jax.lax.fori_loop function, with a little bit of a challenge for you at the end.

As described in the documentation, the body of fori_loop is never executed at runtime when upper<=lower. However, as has been pointed out several times, it is still traced. This can cause issues with out-of-bounds indexing. I understand that the consensus is that this is intended behavior for fori_loop.
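A quick way to see the tracing behavior described above is to record a Python side effect in the loop body. This is a sketch assuming a reasonably recent JAX release. `body` and `traced` are illustrative names, and how many times JAX traces the body is an implementation detail, so the only claim is that it is traced at least once even though zero iterations run.

```python
import jax
import jax.numpy as jnp

traced = []  # illustrative: one entry per time JAX traces the body


def body(i, x):
    traced.append(1)  # Python side effect: runs at trace time, not at runtime
    return x + i


# Static bounds with upper <= lower: zero iterations execute at runtime...
out = jax.lax.fori_loop(0, 0, body, jnp.int32(5))

# ...but the body was still traced in order to build the loop.
print(int(out), len(traced) > 0)
```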

Nevertheless, in my use cases, the Python-like behavior makes things much, much easier conceptually. So in my previous question, I came up with the following wrapper, which overrides the default behavior when the indexing issue occurs:

```python
import jax.numpy as jnp
import jax
from jax.scipy.special import gammaln

# WRAPPER FOR FORI TO HANDLE THE CASE UPPER<=LOWER SEPARATELY
def wrapped_fori(lower, upper, body_fun, init_val, unroll=None):
    if upper<=lower:
        out = init_val
    else:
        out = jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)
    return out


def comb(n, k):
    return jnp.round(jnp.exp(gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)))


def binom_conv(n, Aks, Bks):
    return part_binom_conv(n, 0, n, Aks, Bks)


def part_binom_conv(n, k0, k1, Aks, Bks):
    A_shape = Aks.shape[1:]
    A_dtype = Aks.dtype
    init_conv = jnp.zeros(A_shape, dtype=A_dtype)
    conv = jax.lax.fori_loop(k0, k1, update_binom_conv, (init_conv, n, Aks, Bks))[0]
    return conv


def update_binom_conv(k, val):
    conv, n, Aks, Bks = val
    conv = conv + comb(n-1, k) * Aks[k] @ Bks[(n-1)-k]
    return conv, n, Aks, Bks


@jax.jit
def build(U, Hks):
    n = Hks.shape[0] # n=3 in the test below (n=0 is the edge case)
    H_shape = Hks.shape[1:] # H_shape=(2,2)
    Uks_shape = (n+1,)+H_shape # Uks_shape=(n+1,2,2)
    Uks = jnp.zeros(Uks_shape, dtype=Hks.dtype)
    Uks = Uks.at[0].set(U)
    Uks = wrapped_fori(0, n, update_Uks, (Uks, Hks))[0] # Treats the case n=0 separately
    return Uks


def update_Uks(k, val):
    Uks, Hks = val
    Uks = Uks.at[k+1].set(-1j*binom_conv(k+1, Hks, Uks))
    return Uks, Hks


# Test
Hks = jnp.zeros((3,2,2), dtype=complex)
U = jnp.eye(2, dtype=complex)
build(U, Hks)
```

The above works fine. However, I noticed that I can't replace all my fori_loops with this wrapper. Specifically, it fails when used with nested loops. For example, the following modification of the function part_binom_conv() fails:

```python
import jax.numpy as jnp
import jax
from jax.scipy.special import gammaln

# WRAPPER FOR FORI TO HANDLE THE CASE UPPER<=LOWER SEPARATELY
def wrapped_fori(lower, upper, body_fun, init_val, unroll=None):
    if upper<=lower:
        out = init_val
    else:
        out = jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)
    return out


def comb(n, k):
    return jnp.round(jnp.exp(gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)))


def binom_conv(n, Aks, Bks):
    return part_binom_conv(n, 0, n, Aks, Bks)


def part_binom_conv(n, k0, k1, Aks, Bks):
    A_shape = Aks.shape[1:]
    A_dtype = Aks.dtype
    init_conv = jnp.zeros(A_shape, dtype=A_dtype)
    conv = wrapped_fori(k0, k1, update_binom_conv, (init_conv, n, Aks, Bks))[0] #<--- This causes an error
    return conv


def update_binom_conv(k, val):
    conv, n, Aks, Bks = val
    conv = conv + comb(n-1, k) * Aks[k] @ Bks[(n-1)-k]
    return conv, n, Aks, Bks


@jax.jit
def build(U, Hks):
    n = Hks.shape[0] # n=3 in the test below (n=0 is the edge case)
    H_shape = Hks.shape[1:] # H_shape=(2,2)
    Uks_shape = (n+1,)+H_shape # Uks_shape=(n+1,2,2)
    Uks = jnp.zeros(Uks_shape, dtype=Hks.dtype)
    Uks = Uks.at[0].set(U)
    Uks = wrapped_fori(0, n, update_Uks, (Uks, Hks))[0] # Treats the case n=0 separately
    return Uks


def update_Uks(k, val):
    Uks, Hks = val
    Uks = Uks.at[k+1].set(-1j*binom_conv(k+1, Hks, Uks))
    return Uks, Hks


# Test
Hks = jnp.zeros((3,2,2), dtype=complex)
U = jnp.eye(2, dtype=complex)
build(U, Hks)
```

The error is a TracerBoolConversionError, which I think is related to tracing the condition in my wrapper:

```
---------------------------------------------------------------------------
TracerBoolConversionError                 Traceback (most recent call last)
Cell In[4], line 55
     53 Hks = jnp.zeros((3,2,2), dtype=complex)
     54 U = jnp.eye(2, dtype=complex)
---> 55 build(U, Hks)

    [... skipping hidden 13 frame]

Cell In[4], line 43
     41 Uks = jnp.zeros(Uks_shape, dtype=Hks.dtype)
     42 Uks = Uks.at[0].set(U)
---> 43 Uks = wrapped_fori(0, n, update_Uks, (Uks, Hks))[0] # Treats the case n=0 separately
     44 return Uks

Cell In[4], line 10
      8     out = init_val
      9 else:
---> 10     out = jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)
     11 return out

    [... skipping hidden 12 frame]

Cell In[4], line 48
     46 def update_Uks(k, val):
...
-> 1806     raise TracerBoolConversionError(arg)

TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
The error occurred while tracing the function update_Uks at /var/folders/x0/28x522xx1vb2xl75tn781lqr0000gn/T/ipykernel_54810/1590930335.py:46 for fori_loop. This concrete value was not available in Python because it depends on the value of the argument k.
See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError
```
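The same error can probably be reproduced without the binomial-convolution code. In this minimal sketch (`inner_loop` is a hypothetical name, and the traced bound comes from a `jit` argument rather than from the outer loop index `k`, which should behave the same way), the Python `if` in the wrapper receives a traced `bool[]` value instead of a Python bool:

```python
import jax
import jax.numpy as jnp


def wrapped_fori(lower, upper, body_fun, init_val, unroll=None):
    # Same Python-level branch as in the wrapper above.
    if upper <= lower:
        return init_val
    return jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)


@jax.jit
def inner_loop(n):
    # Inside jit, n is a tracer, so `upper <= lower` is a traced bool[] array
    # and the Python `if` cannot convert it to True/False.
    return wrapped_fori(0, n, lambda i, x: x + i, jnp.int32(0))


try:
    inner_loop(3)
    error = None
except jax.errors.TracerBoolConversionError as e:
    error = e

print(type(error).__name__)
```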

My question is a little bit of a challenge. Is it possible to modify this wrapper for the fori_loop so that it doesn't trace the body when upper<=lower, and that it never causes an error in nested loops?

I understand that this will not be implemented in jax, but I was wondering if it is something I could do in my code.

Answer

> Is it possible to modify this wrapper for the fori_loop so that it doesn't trace the body when upper<=lower...

No, I don't believe that is possible.

The problematic case you point out occurs when the fori_loop start and end points are traced, in which case their concrete values are, by definition, unknown at trace time. You cannot condition tracing behavior on values that are not known at trace time.
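To illustrate: a runtime branch on traced bounds is possible with `jax.lax.cond`, but it does not avoid tracing, because JAX must trace both branches (including the one containing `fori_loop`) before the predicate's value is known. This is a sketch with hypothetical names (`cond_fori`, `run`, `traced`), not a suggested replacement for the wrapper.

```python
import jax
import jax.numpy as jnp

traced = []  # one entry each time the loop body is traced


def body(i, x):
    traced.append(1)  # Python side effect: runs at trace time only
    return x + i


def cond_fori(lower, upper, body_fun, init_val):
    # Hypothetical variant: choose the branch at runtime with lax.cond.
    # Both branches are traced, because the predicate is not known yet.
    return jax.lax.cond(
        upper <= lower,
        lambda v: v,
        lambda v: jax.lax.fori_loop(lower, upper, body_fun, v),
        init_val,
    )


@jax.jit
def run(upper):
    return cond_fori(0, upper, body, jnp.int32(0))


out = run(0)  # the identity branch is selected at runtime...
print(int(out), len(traced) > 0)  # ...yet body was traced anyway
```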

> ... and that it never causes an error in nested loops?

I don't think you need to worry about this. Your previous question ran into an error because the array shapes were related to the loop length, so for a loop length of zero, indexing into the array failed. With dynamic loop endpoints, the array shapes cannot be related to the loop length, because shapes cannot be dynamic. So I don't think you would ever run into an issue where tracing a zero-length dynamic (inner) loop causes problems, unless your code had a bug that would make it error in all cases.
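Following this reasoning, one option (a sketch, not something stated in the answer) is to apply the Python-level shortcut only when both bounds are plain Python or NumPy integers, which are always known at trace time, and to defer to `fori_loop` otherwise. The helper `_static_int` and the toy `nested` example are hypothetical. Note that a concrete `jax.Array` bound is treated as dynamic here, which is safe but skips the shortcut.

```python
import numpy as np
import jax
import jax.numpy as jnp


def _static_int(x):
    """Return x as a Python int if it is a plain Python/NumPy integer, else None."""
    # Hypothetical helper: only these types are guaranteed concrete at trace time.
    if isinstance(x, (int, np.integer)):
        return int(x)
    return None


def wrapped_fori(lower, upper, body_fun, init_val, unroll=None):
    lo, hi = _static_int(lower), _static_int(upper)
    if lo is not None and hi is not None and hi <= lo:
        # Both bounds are known in Python and the range is empty:
        # return init_val without tracing body_fun at all.
        return init_val
    # Otherwise (non-empty static range, or any traced bound) defer to
    # fori_loop, which traces body_fun even if the range is empty at runtime.
    return jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)


# Toy nested loop (hypothetical): the inner upper bound `i` is a tracer.
def inner_body(j, acc):
    return acc + j


def outer_body(i, total):
    return total + wrapped_fori(0, i, inner_body, jnp.int32(0))


@jax.jit
def nested(x):
    n = x.shape[0]  # shapes are static, so n is a Python int here
    return wrapped_fori(0, n, outer_body, jnp.int32(0))


print(nested(jnp.zeros(4)), nested(jnp.zeros(0)))
```

Here `nested(jnp.zeros(4))` sums `j` over `0 <= j < i < 4`, giving 4, while `nested(jnp.zeros(0))` takes the static shortcut and returns the initial value 0 without tracing `outer_body`. With this version, the nested call in `part_binom_conv` from the question should also avoid the `TracerBoolConversionError`, since `k+1` is a tracer and is passed straight to `fori_loop`.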

Comment

This is very insightful, and I think a convincing proof that my previous solution is a general enough one within jax. I swear to God, one day I will stop poking my eye with jax's sharp edges.
