This is a follow-up to a previous question about the jax.lax.fori_loop function, with a little bit of a challenge for you at the end.
As described in the documentation, the body of a fori_loop is never executed at runtime when upper<=lower. However, as has been pointed out several times, the body is still traced. This can cause issues with out-of-bounds indexing. I understand that the consensus is that this is the intended behavior for fori_loop.
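To make the traced-but-not-executed behavior concrete, here is a minimal sketch. The names `trace_calls` and `body` are made up for illustration; the Python-side counter only records how often JAX calls the body function while tracing, not how many iterations run.

```python
import jax
import jax.numpy as jnp

trace_calls = []  # hypothetical helper: filled only while JAX traces the body

def body(i, x):
    trace_calls.append(i)  # Python side effect, so it happens at trace time
    return x + 1.0

# Empty range (upper <= lower): no iteration runs, init_val comes back unchanged...
out = jax.lax.fori_loop(0, 0, body, jnp.float32(0.0))

# ...but the body was still called during tracing, so any trace-time error
# inside it (such as a static out-of-bounds index) would surface anyway.
print(float(out), len(trace_calls))
```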
Nevertheless, in my use cases, the Python-like behavior makes things much, much easier conceptually. So in my previous question, I came up with the following wrapper, which returns init_val directly when upper<=lower instead of calling fori_loop:
    import jax.numpy as jnp
    import jax
    from jax.scipy.special import gammaln

    # WRAPPER FOR FORI TO HANDLE THE CASE UPPER<=LOWER SEPARATELY
    def wrapped_fori(lower, upper, body_fun, init_val, unroll=None):
        if upper<=lower:
            out = init_val
        else:
            out = jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)
        return out

    def comb(n, k):
        return jnp.round(jnp.exp(gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)))

    def binom_conv(n, Aks, Bks):
        return part_binom_conv(n, 0, n, Aks, Bks)

    def part_binom_conv(n, k0, k1, Aks, Bks):
        A_shape = Aks.shape[1:]
        A_dtype = Aks.dtype
        init_conv = jnp.zeros(A_shape, dtype=A_dtype)
        conv = jax.lax.fori_loop(k0, k1, update_binom_conv, (init_conv, n, Aks, Bks))[0]
        return conv

    def update_binom_conv(k, val):
        conv, n, Aks, Bks = val
        conv = conv + comb(n-1, k) * Aks[k] @ Bks[(n-1)-k]
        return conv, n, Aks, Bks

    @jax.jit
    def build(U, Hks):
        n = Hks.shape[0]              # n=0
        H_shape = Hks.shape[1:]       # H_shape=(2,2)
        Uks_shape = (n+1,)+H_shape    # Uks_shape=(1,2,2)
        Uks = jnp.zeros(Uks_shape, dtype=Hks.dtype)
        Uks = Uks.at[0].set(U)
        Uks = wrapped_fori(0, n, update_Uks, (Uks, Hks))[0]  # Treats the case n=0 separately
        return Uks

    def update_Uks(k, val):
        Uks, Hks = val
        Uks = Uks.at[k+1].set(-1j*binom_conv(k+1, Hks, Uks))
        return Uks, Hks

    # Test
    Hks = jnp.zeros((3,2,2), dtype=complex)
    U = jnp.eye(2, dtype=complex)
    build(U, Hks)

The above works fine. However, I noticed that I can't replace all my fori_loops with this wrapper. Specifically, it fails when used with nested loops. For example, the following modification of the function part_binom_conv() fails:
    import jax.numpy as jnp
    import jax
    from jax.scipy.special import gammaln

    # WRAPPER FOR FORI TO HANDLE THE CASE UPPER<=LOWER SEPARATELY
    def wrapped_fori(lower, upper, body_fun, init_val, unroll=None):
        if upper<=lower:
            out = init_val
        else:
            out = jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)
        return out

    def comb(n, k):
        return jnp.round(jnp.exp(gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)))

    def binom_conv(n, Aks, Bks):
        return part_binom_conv(n, 0, n, Aks, Bks)

    def part_binom_conv(n, k0, k1, Aks, Bks):
        A_shape = Aks.shape[1:]
        A_dtype = Aks.dtype
        init_conv = jnp.zeros(A_shape, dtype=A_dtype)
        conv = wrapped_fori(k0, k1, update_binom_conv, (init_conv, n, Aks, Bks))[0]  #<--- This causes an error
        return conv

    def update_binom_conv(k, val):
        conv, n, Aks, Bks = val
        conv = conv + comb(n-1, k) * Aks[k] @ Bks[(n-1)-k]
        return conv, n, Aks, Bks

    @jax.jit
    def build(U, Hks):
        n = Hks.shape[0]              # n=0
        H_shape = Hks.shape[1:]       # H_shape=(2,2)
        Uks_shape = (n+1,)+H_shape    # Uks_shape=(1,2,2)
        Uks = jnp.zeros(Uks_shape, dtype=Hks.dtype)
        Uks = Uks.at[0].set(U)
        Uks = wrapped_fori(0, n, update_Uks, (Uks, Hks))[0]  # Treats the case n=0 separately
        return Uks

    def update_Uks(k, val):
        Uks, Hks = val
        Uks = Uks.at[k+1].set(-1j*binom_conv(k+1, Hks, Uks))
        return Uks, Hks

    # Test
    Hks = jnp.zeros((3,2,2), dtype=complex)
    U = jnp.eye(2, dtype=complex)
    build(U, Hks)

The error is a TracerBoolConversionError, which I think comes from the condition in my wrapper being evaluated on a traced value: inside the outer loop, k is a tracer, so the inner bound k+1 is one too.
    ---------------------------------------------------------------------------
    TracerBoolConversionError                 Traceback (most recent call last)
    Cell In[4], line 55
         53 Hks = jnp.zeros((3,2,2), dtype=complex)
         54 U = jnp.eye(2, dtype=complex)
    ---> 55 build(U, Hks)

        [... skipping hidden 13 frame]

    Cell In[4], line 43
         41 Uks = jnp.zeros(Uks_shape, dtype=Hks.dtype)
         42 Uks = Uks.at[0].set(U)
    ---> 43 Uks = wrapped_fori(0, n, update_Uks, (Uks, Hks))[0] # Treats the case n=0 separately
         44 return Uks

    Cell In[4], line 10
          8     out = init_val
          9 else:
    ---> 10     out = jax.lax.fori_loop(lower, upper, body_fun, init_val, unroll=unroll)
         11 return out

        [... skipping hidden 12 frame]

    Cell In[4], line 48
         46 def update_Uks(k, val):
    ...
    -> 1806 raise TracerBoolConversionError(arg)

    TracerBoolConversionError: Attempted boolean conversion of traced array with shape bool[].
    The error occurred while tracing the function update_Uks at /var/folders/x0/28x522xx1vb2xl75tn781lqr0000gn/T/ipykernel_54810/1590930335.py:46 for fori_loop. This concrete value was not available in Python because it depends on the value of the argument k.
    See https://docs.jax.dev/en/latest/errors.html#jax.errors.TracerBoolConversionError

My question is a little bit of a challenge. Is it possible to modify this wrapper for fori_loop so that it doesn't trace the body when upper<=lower, and so that it never causes an error in nested loops?
I understand that this will not be implemented in JAX itself, but I was wondering if it is something I could do in my own code.
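For reference, one direction I have been considering is to apply the Python-level shortcut only when both bounds are concrete, and to fall back to the plain fori_loop when either bound is a tracer. This is only a sketch, under the assumption that `isinstance(x, jax.core.Tracer)` is a reliable way to detect traced values. The toy functions `inner_sum` and `prefix_sums` are made up for illustration and are not part of my real code. With a traced bound the body is still traced, since the bound's value is not known at trace time, so this only partially meets the goal; it does avoid the TracerBoolConversionError in the nested case.

```python
import jax
import jax.numpy as jnp

def wrapped_fori(lower, upper, body_fun, init_val):
    # Assumption: jax.core.Tracer identifies values that are abstract during tracing.
    bounds_are_static = not isinstance(lower, jax.core.Tracer) and not isinstance(
        upper, jax.core.Tracer
    )
    if bounds_are_static and upper <= lower:
        # Concrete empty range: skip fori_loop entirely, so the body is never traced.
        return init_val
    # Traced (or non-empty) bounds: defer to fori_loop, which runs zero
    # iterations at runtime when upper <= lower.
    return jax.lax.fori_loop(lower, upper, body_fun, init_val)

def inner_sum(k, xs):
    # Sum of xs[0], ..., xs[k-1]; k is a tracer when called from the outer loop body.
    return wrapped_fori(0, k, lambda i, acc: acc + xs[i], jnp.zeros((), xs.dtype))

@jax.jit
def prefix_sums(xs):
    n = xs.shape[0]  # static Python int under jit
    out = jnp.zeros(n + 1, xs.dtype)

    def body(k, out):
        return out.at[k + 1].set(inner_sum(k + 1, xs))

    return wrapped_fori(0, n, body, out)

print(prefix_sums(jnp.arange(1.0, 4.0)))  # nested loops, inner bound is traced
print(prefix_sums(jnp.zeros((0,))))       # n=0: outer body is never traced
```

I dropped the unroll argument in this sketch to keep it short.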