67

Example of scatterplot matrix


Is there such a function in matplotlib.pyplot?

1

5 Answers

126

For those who do not want to define their own functions, there is a great data analysis library in Python called pandas, which provides the scatter_matrix() function:

    import numpy as np
    import pandas as pd
    from pandas.plotting import scatter_matrix

    df = pd.DataFrame(np.random.randn(1000, 4), columns=['a', 'b', 'c', 'd'])
    scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal='kde')
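scatter_matrix() returns the grid of matplotlib Axes as a 2-D NumPy array, so per-panel settings such as the grid can be adjusted afterwards. The following is a minimal sketch of that idea, not part of the original answer. It assumes a headless backend (drop the `matplotlib.use` line for interactive work) and uses `diagonal="hist"` so that SciPy is not required.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: run without a display; omit for interactive use
import numpy as np
import pandas as pd
from pandas.plotting import scatter_matrix

rng = np.random.default_rng(0)
df = pd.DataFrame(rng.standard_normal((200, 4)), columns=["a", "b", "c", "d"])

# diagonal="hist" avoids the SciPy dependency that diagonal="kde" has.
axes = scatter_matrix(df, alpha=0.2, figsize=(6, 6), diagonal="hist")

# axes is a 2-D array with one Axes per panel; switch the grid
# off (or on, with True) uniformly in every panel.
for ax in axes.ravel():
    ax.grid(False)
```

Looping over `axes.ravel()` treats every panel the same way, which should give an all-or-nothing grid instead of a mixed one.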


5 Comments

Hi, how come only part of the subplots have a grid in them? Can that be modified (either all or none)? Thanks
+1 That'll teach me to go searching for a Python feature before looking to see if it's already in pandas. Step 1: Always ask, does it already exist in pandas? pd.scatter_matrix(df); plt.show(). Incredible.
Placing a kde in the matplotlib scatterplot matrix is extreme sport. I love pandas.
Does anyone know where the actual API documentation for pd.tools.plotting.scatter_matrix is? Everywhere that I look I can only find that one example - I can't find the optional arguments for the life of me...
As of pandas 0.20, scatter_matrix has been moved to pandas.plotting.scatter_matrix.
33

Generally speaking, matplotlib doesn't usually contain plotting functions that operate on more than one axes object (subplot, in this case). The expectation is that you'd write a simple function to string things together however you'd like.

I'm not quite sure what your data looks like, but it's quite simple to build a function to do this from scratch. If you're always going to be working with structured or rec arrays, then you can simplify this a touch. (That is, there's always a name associated with each data series, so you don't have to specify the names separately.)

As an example:

    import itertools

    import numpy as np
    import matplotlib.pyplot as plt


    def main():
        np.random.seed(1977)
        numvars, numdata = 4, 10
        data = 10 * np.random.random((numvars, numdata))
        fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
                                 linestyle='none', marker='o', color='black',
                                 mfc='none')
        fig.suptitle('Simple Scatterplot Matrix')
        plt.show()


    def scatterplot_matrix(data, names, **kwargs):
        """Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
        against other rows, resulting in a nrows by nrows grid of subplots with the
        diagonal subplots labeled with "names".  Additional keyword arguments are
        passed on to matplotlib's "plot" command. Returns the matplotlib figure
        object containing the subplot grid."""
        numvars, numdata = data.shape
        fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8, 8))
        fig.subplots_adjust(hspace=0.05, wspace=0.05)

        for ax in axes.flat:
            # Hide all ticks and labels
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

            # Set up ticks only on one side for the "edge" subplots...
            # (ax.is_first_col() and friends were removed in matplotlib 3.6;
            # the SubplotSpec methods are their replacement.)
            spec = ax.get_subplotspec()
            if spec.is_first_col():
                ax.yaxis.set_ticks_position('left')
            if spec.is_last_col():
                ax.yaxis.set_ticks_position('right')
            if spec.is_first_row():
                ax.xaxis.set_ticks_position('top')
            if spec.is_last_row():
                ax.xaxis.set_ticks_position('bottom')

        # Plot the data.
        for i, j in zip(*np.triu_indices_from(axes, k=1)):
            for x, y in [(i, j), (j, i)]:
                axes[x, y].plot(data[x], data[y], **kwargs)

        # Label the diagonal subplots...
        for i, label in enumerate(names):
            axes[i, i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                                ha='center', va='center')

        # Turn on the proper x or y axes ticks.
        for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
            axes[j, i].xaxis.set_visible(True)
            axes[i, j].yaxis.set_visible(True)

        return fig


    main()
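The structured-array simplification mentioned above could look roughly like the sketch below. It is an illustration under assumptions, not the answer's own code: the function name `scatterplot_matrix_structured` is hypothetical, all ticks are simply hidden to keep it short, and a headless backend is selected so it runs without a display.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: run without a display; omit for interactive use
import numpy as np
import matplotlib.pyplot as plt


def scatterplot_matrix_structured(data, **kwargs):
    """Scatterplot matrix for a 1-D structured array.

    The labels come from data.dtype.names, so no separate "names"
    argument is needed. Keyword arguments are passed on to plot().
    """
    names = data.dtype.names
    numvars = len(names)
    fig, axes = plt.subplots(numvars, numvars, figsize=(8, 8), squeeze=False)
    for i, row_name in enumerate(names):
        for j, col_name in enumerate(names):
            ax = axes[i, j]
            ax.set_xticks([])
            ax.set_yticks([])
            if i == j:
                # Diagonal panels only carry the field name.
                ax.annotate(row_name, (0.5, 0.5), xycoords="axes fraction",
                            ha="center", va="center")
            else:
                # Column field on x, row field on y.
                ax.plot(data[col_name], data[row_name], **kwargs)
    return fig


# Build a small structured array with named fields.
rng = np.random.default_rng(1977)
fields = [("mpg", float), ("disp", float), ("drat", float), ("wt", float)]
data = np.zeros(10, dtype=fields)
for name in data.dtype.names:
    data[name] = 10 * rng.random(10)

fig = scatterplot_matrix_structured(data, linestyle="none", marker="o",
                                    color="black", mfc="none")
```

Because each field is addressed by name (`data["mpg"]`), the caller cannot pass labels that are out of step with the data.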


6 Comments

Wow, many new functions! Yes, not too difficult when you have mastery of the module... but not as simple as calling pairs as in R. :)
True! R has a lot more specialized functions, in my (limited!) experience with it. Matplotlib has a slightly more DIY approach. (Or certainly a lot fewer specialized statistical plotting functions, at any rate.)
Certainly I feel this way. I'm sticking with the Python trio (for now) in hopes though that it offers other advantages...
In my opinion, the big advantage is Python's flexibility. R is a fantastic domain-specific language, and if you just want to do statistical analysis, it's unmatched. Python is a nice general programming language, and you'll really start to see the benefits with larger programs. Once you want a program with an interactive GUI that grabs data from the web, parses some random binary file format, does your analysis, and plots it all up, a general programming language begins to show a lot of advantages. Of course, that's true for a lot of languages, but I prefer Python. :)
@Joe Kington, firstly, thanks for this example (I use it regularly) and all your other mpl examples! A couple of points: 1. For those wishing to match R, the x and y values are backwards: change plot axes[x,y] to axes[y,x]. 2. set sharex='col', sharey='row' in subplots() 3. diagonal affects the tick limits, so either set the limits or plot axes[i,i].plot(data[i], data[i], linestyle='None') 4. if data is in row, col format, then input must be transposed, data.T
18

You can also use Seaborn's pairplot function:

    import seaborn as sns

    sns.set()
    df = sns.load_dataset("iris")
    sns.pairplot(df, hue="species")
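pairplot also accepts `x_vars` and `y_vars`, which restrict the columns and rows of the grid to two different sets of variables, and a plain NumPy array can be wrapped in a DataFrame first. A minimal sketch, assuming made-up column names `a` to `e`, random data, and a headless backend:

```python
import matplotlib
matplotlib.use("Agg")  # assumption: run without a display; omit for interactive use
import numpy as np
import pandas as pd
import seaborn as sns

rng = np.random.default_rng(0)
arr = rng.standard_normal((100, 5))  # plain NumPy array, one column per variable

# Wrap the array in a DataFrame; the column names are illustrative only.
df = pd.DataFrame(arr, columns=["a", "b", "c", "d", "e"])

# Columns of the grid show a, b, c; rows show d, e.
grid = sns.pairplot(df, x_vars=["a", "b", "c"], y_vars=["d", "e"])
```

The returned PairGrid exposes the panels as `grid.axes`, with one row per entry in `y_vars` and one column per entry in `x_vars`.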

2 Comments

The annoying part about seaborn is that it's centered around pandas DataFrames. If you have a NumPy array, this workaround feels annoying, and if you already have a pandas DataFrame, why not just use pandas' built-in scatter_matrix function?
Unfortunately, it does not allow scatterplot matrices formed by two distinct groups of variables. It just gives vars vs vars plot. This complicates analysis for medium-sized and large datasets.
11

Thanks for sharing your code! You figured out all the hard stuff for us. As I was working with it, I noticed a few little things that didn't look quite right.

  1. [FIX #1] The axis ticks weren't lining up as I would expect. In your example above, you should be able to draw a vertical and a horizontal line through any point across all plots, and the lines should pass through the corresponding point in the other plots, but as it stands this doesn't happen.

  2. [FIX #2] If you are plotting an odd number of variables, the bottom-right corner axes doesn't pick up the correct x or y ticks. It is just left with the default 0..1 ticks.

  3. Not a fix, but I made it optional to explicitly input names, so that it puts a default xi for variable i in the diagonal positions.

Below you'll find an updated version of your code that addresses these two points, otherwise preserving the beauty of your code.

    import itertools

    import numpy as np
    import matplotlib.pyplot as plt


    def scatterplot_matrix(data, names=None, **kwargs):
        """
        Plots a scatterplot matrix of subplots.  Each row of "data" is plotted
        against other rows, resulting in a nrows by nrows grid of subplots with the
        diagonal subplots labeled with "names".  Additional keyword arguments are
        passed on to matplotlib's "plot" command. Returns the matplotlib figure
        object containing the subplot grid.
        """
        numvars, numdata = data.shape
        fig, axes = plt.subplots(nrows=numvars, ncols=numvars, figsize=(8, 8))
        fig.subplots_adjust(hspace=0.0, wspace=0.0)

        for ax in axes.flat:
            # Hide all ticks and labels
            ax.xaxis.set_visible(False)
            ax.yaxis.set_visible(False)

            # Set up ticks only on one side for the "edge" subplots...
            # (ax.is_first_col() and friends were removed in matplotlib 3.6;
            # the SubplotSpec methods are their replacement.)
            spec = ax.get_subplotspec()
            if spec.is_first_col():
                ax.yaxis.set_ticks_position('left')
            if spec.is_last_col():
                ax.yaxis.set_ticks_position('right')
            if spec.is_first_row():
                ax.xaxis.set_ticks_position('top')
            if spec.is_last_row():
                ax.xaxis.set_ticks_position('bottom')

        # Plot the data.
        for i, j in zip(*np.triu_indices_from(axes, k=1)):
            for x, y in [(i, j), (j, i)]:
                # FIX #1: this needed to be changed from ...(data[x], data[y],...)
                axes[x, y].plot(data[y], data[x], **kwargs)

        # Label the diagonal subplots...
        if not names:
            names = ['x' + str(i) for i in range(numvars)]

        for i, label in enumerate(names):
            axes[i, i].annotate(label, (0.5, 0.5), xycoords='axes fraction',
                                ha='center', va='center')

        # Turn on the proper x or y axes ticks.
        for i, j in zip(range(numvars), itertools.cycle((-1, 0))):
            axes[j, i].xaxis.set_visible(True)
            axes[i, j].yaxis.set_visible(True)

        # FIX #2: if numvars is odd, the bottom right corner plot doesn't have the
        # correct axes limits, so we pull them from other axes
        if numvars % 2:
            xlimits = axes[0, -1].get_xlim()
            ylimits = axes[-1, 0].get_ylim()
            axes[-1, -1].set_xlim(xlimits)
            axes[-1, -1].set_ylim(ylimits)

        return fig


    if __name__ == '__main__':
        np.random.seed(1977)
        numvars, numdata = 4, 10
        data = 10 * np.random.random((numvars, numdata))
        fig = scatterplot_matrix(data, ['mpg', 'disp', 'drat', 'wt'],
                                 linestyle='none', marker='o', color='black',
                                 mfc='none')
        fig.suptitle('Simple Scatterplot Matrix')
        plt.show()

Thanks again for sharing this with us. I have used it many times! Oh, and I rearranged the main() part of the code so that it runs as an example when the file is executed directly but is not called when the file is imported into another piece of code.
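Both fixes concern axis alignment, so another option is to let matplotlib keep the limits in sync by sharing the x axis within each column and the y axis within each row. The sketch below is a variant under assumptions, not the code from this answer: it uses five variables (an odd count, on purpose), default `xi` labels, and a headless backend.

```python
import matplotlib
matplotlib.use("Agg")  # assumption: run without a display; omit for interactive use
import numpy as np
import matplotlib.pyplot as plt

rng = np.random.default_rng(1977)
numvars, numdata = 5, 10  # odd number of variables on purpose
data = 10 * rng.random((numvars, numdata))

# sharex="col": every panel in a column uses the same x limits.
# sharey="row": every panel in a row uses the same y limits.
fig, axes = plt.subplots(numvars, numvars, sharex="col", sharey="row",
                         figsize=(8, 8))
fig.subplots_adjust(hspace=0.0, wspace=0.0)

for i in range(numvars):
    for j in range(numvars):
        if i == j:
            # Diagonal panel: label only; its limits come from its neighbours.
            axes[i, j].annotate("x" + str(i), (0.5, 0.5),
                                xycoords="axes fraction",
                                ha="center", va="center")
        else:
            # Variable j on x, variable i on y (same orientation as FIX #1).
            axes[i, j].plot(data[j], data[i], linestyle="none", marker="o",
                            color="black", mfc="none")
```

With shared axes the corner panel inherits its limits from the rest of its row and column, so the odd-count special case should not be needed in this variant.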

2 Comments

Thanks, I was having the problems with @Joe Kington's code until I saw your answer. It saved me some debugging time :)
Any idea how I can make this function faster? I need to generate a big scatterplot matrix of around 100 variables, and this method is very slow.
6

While reading the question I expected to see an answer including rpy. I think this is a nice option taking advantage of two beautiful languages. So here it is:

    import rpy
    import numpy as np


    def main():
        np.random.seed(1977)
        numvars, numdata = 4, 10
        data = 10 * np.random.random((numvars, numdata))
        mpg = data[0, :]
        disp = data[1, :]
        drat = data[2, :]
        wt = data[3, :]
        rpy.set_default_mode(rpy.NO_CONVERSION)

        R_data = rpy.r.data_frame(mpg=mpg, disp=disp, drat=drat, wt=wt)

        # Figure saved as eps
        rpy.r.postscript('pairsPlot.eps')
        rpy.r.pairs(R_data, main="Simple Scatterplot Matrix Via RPy")
        rpy.r.dev_off()

        # Figure saved as png
        rpy.r.png('pairsPlot.png')
        rpy.r.pairs(R_data, main="Simple Scatterplot Matrix Via RPy")
        rpy.r.dev_off()

        rpy.set_default_mode(rpy.BASIC_CONVERSION)


    if __name__ == '__main__':
        main()

I can't post an image to show the result :( sorry!
