56

When several elements share the maximum value, numpy.argmax breaks the tie by returning the index of the first one. Is there a way to randomize the tie-breaking so that every maximum element has an equal chance of being selected?

Below is an example directly from numpy.argmax documentation.

>>> b = np.arange(6)
>>> b[1] = 5
>>> b
array([0, 5, 2, 3, 4, 5])
>>> np.argmax(b)  # Only the first occurrence is returned.
1

I am looking for a way to have indices 1 and 5 (the two positions holding the maximum value 5) returned with equal probability.

Thank you!


6 Answers

61

Use np.random.choice -

np.random.choice(np.flatnonzero(b == b.max())) 

Let's verify for an array with three max candidates -

In [298]: b
Out[298]: array([0, 5, 2, 5, 4, 5])

In [299]: c = [np.random.choice(np.flatnonzero(b == b.max())) for i in range(100000)]

In [300]: np.bincount(c)
Out[300]: array([    0, 33180,     0, 33611,     0, 33209])

4 Comments

if you have floats instead of ints, you may want to replace b == b.max() with np.isclose(b, b.max())
arr.max() can be replaced by np.max(arr) for clarity that numpy array is involved
roughly three times slower!
For better performance you could try max(enumerate(b), key=lambda pair: (pair[1], rng.random()))[0], where rng is a random number generator (for example np.random.default_rng()).
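The float comment above can be turned into a small helper. This is only a sketch: the name random_argmax_close and its rng, rtol and atol parameters are made up for illustration, and the tolerances are simply the np.isclose defaults, which may not suit every dataset.

```python
import numpy as np

def random_argmax_close(values, rng=None, rtol=1e-05, atol=1e-08):
    """Index of a maximum of a 1-D array, breaking near-ties uniformly at random."""
    values = np.asarray(values)
    # Hypothetical convenience: fall back to a fresh Generator if none is given.
    rng = np.random.default_rng() if rng is None else rng
    # Treat every entry within tolerance of the maximum as tied.
    candidates = np.flatnonzero(
        np.isclose(values, values.max(), rtol=rtol, atol=atol)
    )
    return int(rng.choice(candidates))

# 0.1 + 0.2 is not exactly 0.3 in floating point, so b == b.max() would see no tie.
b = np.array([0.1 + 0.2, 0.3, 0.0])
idx = random_argmax_close(b, rng=np.random.default_rng(0))
```

With exact comparison only index 0 would ever be returned here; with the tolerance, indices 0 and 1 are both candidates.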
16

In the case of a multi-dimensional array, choice won't work.

An alternative is

def randargmax(b, **kw):
    """A random tie-breaking argmax."""
    return np.argmax(np.random.random(b.shape) * (b == b.max()), **kw)

If for some reason generating random floats is slower than some other method, random.random can be replaced with that other method.

3 Comments

roughly 18 times slower!
**kw should also be passed to max() (e.g., axis), so it's correctly calculated (e.g., the max per row instead of the global max).
Also, there's a chance the random value is 0. I think doing a +1 to the random value fixes this unlikely case.
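The two comments above (pass the axis to max, and guard against a random value of exactly 0) can be combined roughly as follows. This is a sketch rather than the answer's own code: randargmax_axis and its rng parameter are illustrative names, and it assumes a NumPy version that provides np.random.default_rng.

```python
import numpy as np

def randargmax_axis(b, axis=None, rng=None):
    """Random tie-breaking argmax, globally (axis=None) or along one axis."""
    b = np.asarray(b)
    rng = np.random.default_rng() if rng is None else rng
    # keepdims=True lets the per-axis maximum broadcast against b.
    is_max = b == b.max(axis=axis, keepdims=True)
    # Keys lie in [1, 2) at maximum positions and are 0 elsewhere,
    # so a maximum position always beats a non-maximum one.
    keys = (rng.random(b.shape) + 1.0) * is_max
    return np.argmax(keys, axis=axis)

m = np.array([[1, 5, 5],
              [7, 2, 7]])
per_row = randargmax_axis(m, axis=1)  # one column index per row
flat = randargmax_axis(m)             # index into the flattened array
```

As with np.argmax, the axis=None result is an index into the flattened array; np.unravel_index can convert it back to coordinates.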
9

Easiest way is

np.random.choice(np.where(b == b.max())[0]) 

1 Comment

Please provide some explanation along with the answer.
8

Since the accepted answer may not be obvious, here is how it works:

  • b == b.max() returns an array of booleans that is True where an item equals the maximum and False everywhere else
  • flatnonzero() returns the indices of the nonzero entries of the flattened array; since False counts as zero, you get an array with the indices of the items matching the max value
  • Finally, np.random.choice picks one index uniformly at random from that array
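The three steps can be spelled out one at a time. The intermediate names mask, indices and pick are only for illustration, and the array is the one from the accepted answer's verification.

```python
import numpy as np

b = np.array([0, 5, 2, 5, 4, 5])

# Step 1: boolean mask marking the positions that hold the maximum.
mask = b == b.max()              # [False, True, False, True, False, True]

# Step 2: indices of the True entries.
indices = np.flatnonzero(mask)   # [1, 3, 5]

# Step 3: pick one of those indices uniformly at random.
pick = np.random.choice(indices)
```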

Comments

2

In addition to @Manux's answer:

Changing b.max() to np.amax(b, **kw, keepdims=True) will let you do it along axes.

def randargmax(b, **kw):
    """A random tie-breaking argmax."""
    return np.argmax(np.random.random(b.shape) * (b == np.amax(b, **kw, keepdims=True)), **kw)

randargmax(b, axis=None)

Comments

2

Here is a comparison between the two main solutions by @divakar and @shyam-padia :

method (1) - using np.where

np.random.choice(np.where(b == b.max())[0]) 

method (2) - using np.flatnonzero

np.random.choice(np.flatnonzero(b == b.max()))

Code

Here is the code I wrote for the comparison:

import time

import numpy as np

def method1(b, bmax):
    return np.random.choice(np.where(b == bmax)[0])

def method2(b, bmax):
    return np.random.choice(np.flatnonzero(b == bmax))

def time_it(n):
    b = np.array([1.0, 2.0, 5.0, 5.0, 0.4, 0.1, 5.0, 0.3, 0.1])
    bmax = b.max()

    # Time n calls of method1 (np.where).
    start = time.perf_counter()
    for i in range(n):
        method1(b, bmax)
    elapsed1 = time.perf_counter() - start

    # Time n calls of method2 (np.flatnonzero).
    start = time.perf_counter()
    for i in range(n):
        method2(b, bmax)
    elapsed2 = time.perf_counter() - start

    print(f'method1 time: {elapsed1} - method2 time: {elapsed2}')
    return elapsed1, elapsed2

Results

The following figure shows the total time taken by each method for 100, 1000, 10000, 100000 and 1000000 iterations. The x-axis is the number of iterations (logarithmic scale) and the y-axis is the time in seconds. In these runs np.where was faster than np.flatnonzero, and the gap grew with the number of iterations.

[Figure: time in seconds versus number of iterations for both methods, logarithmic x-axis]

To show how the two methods compare at the lower iteration counts, we can re-plot the same results with a logarithmic y-axis as well. np.where remains faster than np.flatnonzero across the whole range.

[Figure: the same results re-plotted with a logarithmic y-axis]
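Since a single timed loop can be noisy, the comparison can also be repeated with the standard timeit module. This is only a sketch: pick_where, pick_flatnonzero and best_time are made-up names, the repeat counts are arbitrary, and no particular outcome is assumed, since results depend on the machine and the NumPy version.

```python
import timeit

import numpy as np

b = np.array([1.0, 2.0, 5.0, 5.0, 0.4, 0.1, 5.0, 0.3, 0.1])
bmax = b.max()

def pick_where():
    # Method (1): np.where returns a tuple, so take its first element.
    return np.random.choice(np.where(b == bmax)[0])

def pick_flatnonzero():
    # Method (2): np.flatnonzero returns the index array directly.
    return np.random.choice(np.flatnonzero(b == bmax))

def best_time(func, number=2000, repeat=3):
    """Best per-call time in seconds over `repeat` runs of `number` calls."""
    return min(timeit.repeat(func, number=number, repeat=repeat)) / number

if __name__ == "__main__":
    print(f"np.where:       {best_time(pick_where):.3e} s per call")
    print(f"np.flatnonzero: {best_time(pick_flatnonzero):.3e} s per call")
```

Taking the minimum over several repeats is the usual way to reduce the influence of background load on the measurement.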

Comments
