
As I understand it, NumPy's argmax retrieves the maximum value along an axis, in the same way that sum adds up the values along an axis. For example,

    import numpy as np

    x = np.array([[12, 11, 10, 9], [16, 15, 14, 13], [20, 19, 18, 17]])
    print(x)
    print(x.sum(axis=1))
    print(x.sum(axis=0))

would output:

    [[12 11 10  9]
     [16 15 14 13]
     [20 19 18 17]]
    [42 58 74]
    [48 45 42 39]

This makes sense, as the sum along axis 1 (the sum of each row) is [42 58 74] and the sum along axis 0 (the sum of each column) is [48 45 42 39]. However, I am confused about how argmax works. From my understanding, argmax is supposed to return the maximum number along the axis. Below are my code and output.

    print(np.argmax(x, axis=1))   # output: [0 0 0]
    print(np.argmax(x, axis=0))   # output: [2 2 2 2]

Where do the 0 and the 2 come from? I deliberately used larger integer values (9 to 20) so that the 0 and 2 in the output could not be confused with values inside the array.

  • May I know why it was downvoted? Commented May 22, 2018 at 6:19
  • What do you mean by "index of maximum value"? Commented May 22, 2018 at 6:20
  • I googled it and still don't understand. I even tried coding it out to understand how it works. Do you think I would go through the trouble of writing a long post and experimenting without googling first? Commented May 22, 2018 at 6:21
  • np.max returns the maximum values along the given axis; np.argmax returns 'where' those values occur, i.e. their indices. Commented May 22, 2018 at 6:26
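The distinction in the last comment can be checked directly on the question's array. This is a small sketch that simply prints np.max and np.argmax side by side for both axes; the expected outputs in the comments follow from the values in x.

```python
import numpy as np

x = np.array([[12, 11, 10, 9],
              [16, 15, 14, 13],
              [20, 19, 18, 17]])

# np.max gives the largest *values*; np.argmax gives their *positions*.
print(np.max(x, axis=1))     # [12 16 20]     largest value in each row
print(np.argmax(x, axis=1))  # [0 0 0]        column index of that value in each row
print(np.max(x, axis=0))     # [20 19 18 17]  largest value in each column
print(np.argmax(x, axis=0))  # [2 2 2 2]      row index of that value in each column
```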

2 Answers


np.argmax(x, axis=1) returns the index of the maximum in every row.

axis=1 means "along axis 1", i.e. across each row, giving one result per row.

    [[12 11 10  9]   <-- max at index 0
     [16 15 14 13]   <-- max at index 0
     [20 19 18 17]]  <-- max at index 0

Thus its output is [0 0 0].
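To make "index of the maximum in every row" concrete, here is a sketch of the same computation written as an explicit loop. The helper name argmax_rows is hypothetical (it is not a NumPy function); it scans each row and keeps the column position of the largest entry, taking the first one on ties, which matches NumPy's documented behavior for argmax.

```python
import numpy as np

x = np.array([[12, 11, 10, 9],
              [16, 15, 14, 13],
              [20, 19, 18, 17]])

def argmax_rows(a):
    """Hypothetical helper: per-row argmax written as a plain loop.

    For each row, record the column position of its largest entry.
    """
    result = []
    for row in a:
        best = 0
        for j in range(1, len(row)):
            if row[j] > row[best]:  # strict '>' keeps the first maximum on ties
                best = j
        result.append(best)
    return np.array(result)

print(argmax_rows(x))                                     # [0 0 0]
print(np.array_equal(argmax_rows(x), np.argmax(x, axis=1)))  # True
```

The result is a position (a column number), never a value from the array, which is why the output is 0 even though the row contains 12, 16 or 20.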

It's similar for np.argmax(x, axis=0), but now it returns the index of the maximum in every column. Every column has its largest value in the last row (index 2), so the output is [2 2 2 2].
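A sketch of the column case, which also shows that argmax removes the axis it runs along (shape (3, 4) becomes (4,)) and that the returned indices can be used to fetch the maxima back out of x. It assumes NumPy 1.15 or later, where np.take_along_axis is available; the indices need the reduced axis restored (via np.expand_dims) before they can be passed to it.

```python
import numpy as np

x = np.array([[12, 11, 10, 9],
              [16, 15, 14, 13],
              [20, 19, 18, 17]])

col_idx = np.argmax(x, axis=0)  # row index of the maximum in each column
print(col_idx)                  # [2 2 2 2]
print(x.shape, col_idx.shape)   # (3, 4) (4,)  -> axis 0 (length 3) is reduced away

# Put the reduced axis back so the index array has the same ndim as x,
# then pick x[col_idx[j], j] for every column j.
values = np.take_along_axis(x, np.expand_dims(col_idx, axis=0), axis=0)
print(values)                                           # [[20 19 18 17]]
print(np.array_equal(values[0], np.max(x, axis=0)))     # True
```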


1 Comment

Thank you for the response. It's good to see people replying kindly, unlike how the OP was initially downvoted and rebuked by another user.

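Correction: axis=0 refers to rows, not to columns, and axis=1 refers to columns, not to rows. Axis 0 is the row index, so argmax along axis 0 returns a row index for each column; axis 1 is the column index, so argmax along axis 1 returns a column index for each row.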

    >>> import numpy as np
    >>> x = np.array([[12, 11, 10, 9], [16, 15, 14, 13], [20, 19, 18, 17]])
    >>> print(x)
    [[12 11 10  9]
     [16 15 14 13]
     [20 19 18 17]]
    >>> np.argmax(x, axis=0)
    array([2, 2, 2, 2])   # third row: index 2 in each of the 4 columns
    >>> np.argmax(x, axis=1)
    array([0, 0, 0])      # first column: index 0 in each of the 3 rows
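Because the maxima in the question's array all sit in the same row and column, every index comes out identical. The sketch below uses a hypothetical array y (not from the question) whose maxima are scattered, so the per-column and per-row indices differ. It also shows argmax with no axis, which returns an index into the flattened array; np.unravel_index converts that back to (row, column). The value 9 occurs twice in y, and argmax reports the first occurrence.

```python
import numpy as np

# Hypothetical example data: maxima deliberately placed in different positions.
y = np.array([[3, 9, 1, 4],
              [8, 2, 7, 6],
              [5, 0, 9, 4]])

print(np.argmax(y, axis=0))  # [1 0 2 1]  per column: row holding the max
print(np.argmax(y, axis=1))  # [1 0 2]    per row: column holding the max

flat = np.argmax(y)          # no axis: index into the flattened array
print(flat)                  # 1  (9 appears twice; the first occurrence wins)
row, col = np.unravel_index(flat, y.shape)
print(int(row), int(col))    # 0 1
```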
