The most straightforward way to do this is probably to create an array of zeros, and set a random index to 1. In NumPy, it might look like this:
    import numpy as np

    K, M, N = 5, 3, 2
    i = np.random.randint(0, M, K)
    j = np.random.randint(0, N, K)
    x = np.zeros((K, M, N))
    x[np.arange(K), i, j] = 1
In JAX, it might look something like this:
    import jax.numpy as jnp
    from jax import random

    K, M, N = 5, 3, 2
    key1, key2 = random.split(random.PRNGKey(0))
    i = random.randint(key1, (K,), 0, M)
    j = random.randint(key2, (K,), 0, N)
    x = jnp.zeros((K, M, N)).at[jnp.arange(K), i, j].set(1)
A more concise option, which also guarantees a single 1 per slice, is to draw one random flat index in [0, M * N) per slice and compare it, via broadcasting, against a range reshaped to (M, N):
    r = random.randint(random.PRNGKey(0), (K, 1, 1), 0, M * N)
    x = (r == jnp.arange(M * N).reshape(M, N)).astype(int)
As a sanity check, x.sum(axis=(1, 2)) should equal 1 for each of the K slices.
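To check the one-1-per-slice property, here is a minimal sketch of the same broadcast-equality trick in plain NumPy. The use of `np.random.default_rng` and the seed `0` are my own assumptions for reproducibility, not part of the original snippets.

```python
import numpy as np

K, M, N = 5, 3, 2
rng = np.random.default_rng(0)  # assumed seed, only for reproducibility

# One flat position in [0, M*N) per slice, shaped (K, 1, 1) so that it
# broadcasts against the (M, N) grid of flat indices.
r = rng.integers(0, M * N, size=(K, 1, 1))
x = (r == np.arange(M * N).reshape(M, N)).astype(int)

# Count the ones in each (M, N) slice; every count should be exactly 1.
counts = x.sum(axis=(1, 2))
assert x.shape == (K, M, N)
assert (counts == 1).all()
```

Because each slice compares a single integer against M * N distinct values, exactly one position can match, so the check passes for any seed.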