I need to create a 3D tensor of shape (5, 3, 2), for example:

```python
array([[[0, 0], [0, 1], [0, 0]],
       [[1, 0], [0, 0], [0, 0]],
       [[0, 0], [1, 0], [0, 0]],
       [[0, 0], [0, 0], [1, 0]],
       [[0, 0], [0, 1], [0, 0]]])
```

There should be exactly one 1, at a random position, in every slice (if you think of the tensor as a loaf of bread, the slices are the five 3 × 2 matrices along the first axis). This could be done with loops, but I want to vectorize this part.
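For comparison, a loop-based version of what the question describes might look like the sketch below. The function name `one_hot_slices_loop` is hypothetical, and the sketch uses NumPy's `Generator` API (`np.random.default_rng`); the final print checks that every slice sums to 1.

```python
import numpy as np

def one_hot_slices_loop(K, M, N, rng=None):
    """Hypothetical loop baseline: one 1 at a random (row, col) in each of K slices."""
    rng = np.random.default_rng() if rng is None else rng
    t = np.zeros((K, M, N), dtype=int)
    for k in range(K):
        # pick a row and a column independently and uniformly for slice k
        i = rng.integers(M)
        j = rng.integers(N)
        t[k, i, j] = 1
    return t

t = one_hot_slices_loop(5, 3, 2)
print(t.shape)              # (5, 3, 2)
print(t.sum(axis=(1, 2)))   # [1 1 1 1 1]
```

The vectorized answers below all aim to produce the same kind of array without the Python-level loop over `k`.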

  • t.sum(axis=(1,2)) == 1? Commented Feb 16, 2021 at 3:05
  • Well, yes, but I want to generate t randomly. @QuangHoang Commented Feb 16, 2021 at 3:07
  • See some of the techniques discussed here: stackoverflow.com/questions/19597473/… Commented Feb 16, 2021 at 3:08
  • @blorgon I went through those. None of them satisfies the one-per-slice condition. Commented Feb 16, 2021 at 3:10
  • You could certainly apply the techniques in that post, but @QuangHoang's answer is quite clever. Commented Feb 16, 2021 at 3:19

3 Answers

Answer (score 4)

Try generating a random array, then marking the maximum of each slice:

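```python
import numpy as np

a = np.random.rand(5, 3, 2)
# 1 wherever an entry equals the maximum of its slice, 0 elsewhere
out = (a == a.max(axis=(1, 2))[:, None, None]).astype(int)
```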

1 Comment

This is not a great approach, because there is a small but nonzero chance that the maximum will appear more than once, leading to multiple 1s in a slice.
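One way to keep this answer's random-array idea while ruling out ties is to call `argmax` on each flattened slice: `argmax` always returns a single index (the first occurrence of the maximum), so exactly one position per slice is chosen. The sketch below is an alternative to the answer's code, not part of it:

```python
import numpy as np

K, M, N = 5, 3, 2
rng = np.random.default_rng()
a = rng.random((K, M, N))

# argmax over each flattened slice returns exactly one index per slice,
# even if the maximum value happens to repeat
flat_idx = a.reshape(K, -1).argmax(axis=1)

# one-hot encode: compare every flat position 0..M*N-1 with the chosen index
out = (np.arange(M * N) == flat_idx[:, None]).astype(int).reshape(K, M, N)

print(out.sum(axis=(1, 2)))  # [1 1 1 1 1]
```

Since the entries of `a` are i.i.d. continuous uniforms, each cell is equally likely to hold the maximum, so the position of the 1 stays uniform over the slice.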
Answer (score 3)

The most straightforward way to do this is probably to create an array of zeros, and set a random index to 1. In NumPy, it might look like this:

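```python
import numpy as np

K, M, N = 5, 3, 2
i = np.random.randint(0, M, K)  # random row index for each slice
j = np.random.randint(0, N, K)  # random column index for each slice
x = np.zeros((K, M, N))
x[np.arange(K), i, j] = 1       # fancy indexing sets one entry per slice
```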
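A variant of the same zero-then-assign approach, sketched with NumPy's newer `Generator` API: draw one flat index per slice from `[0, M*N)` and split it into a (row, column) pair with `np.unravel_index`. Drawing the flat index uniformly is equivalent in distribution to drawing `i` and `j` independently. The fixed seed is only there to make the sketch reproducible.

```python
import numpy as np

K, M, N = 5, 3, 2
rng = np.random.default_rng(seed=0)  # seeded only for reproducibility

# one flat position per slice, uniform over the M*N cells
flat = rng.integers(0, M * N, size=K)
# convert flat positions to (row, column) indices within an (M, N) slice
i, j = np.unravel_index(flat, (M, N))

x = np.zeros((K, M, N), dtype=int)
x[np.arange(K), i, j] = 1  # one assignment per slice via fancy indexing
```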

In JAX, it might look something like this:

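```python
import jax.numpy as jnp
from jax import random

K, M, N = 5, 3, 2
key1, key2 = random.split(random.PRNGKey(0))
i = random.randint(key1, (K,), 0, M)  # random row index for each slice
j = random.randint(key2, (K,), 0, N)  # random column index for each slice
# JAX arrays are immutable, so .at[...].set(...) returns an updated copy
x = jnp.zeros((K, M, N)).at[jnp.arange(K), i, j].set(1)
```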

A more concise option that also guarantees a single 1 per slice would be to use broadcasted equality of a random integer with an appropriately constructed range:

```python
r = random.randint(random.PRNGKey(0), (K, 1, 1), 0, M * N)
x = (r == jnp.arange(M * N).reshape(M, N)).astype(int)
```
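The broadcasted-equality trick carries over directly to NumPy for readers not using JAX. In this sketch (a port, not the answer's code), `r` has shape `(K, 1, 1)` so that it broadcasts against the `(M, N)` grid of flat positions, giving a `(K, M, N)` result with exactly one 1 per slice:

```python
import numpy as np

K, M, N = 5, 3, 2
rng = np.random.default_rng()

# one random flat position per slice, shaped (K, 1, 1) for broadcasting
r = rng.integers(0, M * N, size=(K, 1, 1))
# grid[m, n] == m * N + n, the flat position of each cell in a slice
grid = np.arange(M * N).reshape(M, N)

# (K, 1, 1) == (M, N) broadcasts to (K, M, N); exactly one True per slice
x = (r == grid).astype(int)
```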

Answer (score 0)

You can create a zero array in which the first element of each slice is 1, and then shuffle it independently along each of the last two axes:

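```python
import numpy as np

x = np.zeros((5, 3, 2))
x[:, 0, 0] = 1  # put the single 1 in the top-left cell of every slice
rng = np.random.default_rng()
# shuffle along the last axis, then along the middle axis;
# permuted() shuffles each 1D slice along the given axis independently
x = rng.permuted(rng.permuted(x, axis=-1), axis=-2)
```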
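Because the two-step shuffle simply moves the single 1 to a uniformly random cell, an equivalent sketch (an alternative, not the answer's code) flattens each slice and calls `rng.permuted` once along the flattened axis. `Generator.permuted` shuffles each row independently and requires NumPy 1.20 or later:

```python
import numpy as np

K, M, N = 5, 3, 2
rng = np.random.default_rng()

# each flattened slice starts as [1, 0, 0, ..., 0]
x = np.zeros((K, M * N), dtype=int)
x[:, 0] = 1

# shuffle every row independently, then restore the (M, N) slice shape
x = rng.permuted(x, axis=1).reshape(K, M, N)
```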
