Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions doc/source/whatsnew/v0.16.1.txt
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ Bug Fixes

- Bug in ``DataFrame.plot(kind="hist")`` results in ``TypeError`` when ``DataFrame`` contains non-numeric columns (:issue:`9853`)
- Bug where repeated plotting of ``DataFrame`` with a ``DatetimeIndex`` may raise ``TypeError`` (:issue:`9852`)
- Bug in plotting ``secondary_y`` incorrectly attaches ``right_ax`` property to secondary axes specifying itself recursively. (:issue:`9861`)

- Bug in ``Series.quantile`` on empty Series of type ``Datetime`` or ``Timedelta`` (:issue:`9675`)
- Bug in ``where`` causing incorrect results when upcasting was required (:issue:`9731`)
Expand Down
3 changes: 3 additions & 0 deletions pandas/tests/test_graphics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1609,7 +1609,10 @@ def test_line_lim(self):
self.assertEqual(xmax, lines[0].get_data()[0][-1])

axes = df.plot(secondary_y=True, subplots=True)
self._check_axes_shape(axes, axes_num=3, layout=(3, 1))
for ax in axes:
self.assertTrue(hasattr(ax, 'left_ax'))
self.assertFalse(hasattr(ax, 'right_ax'))
xmin, xmax = ax.get_xlim()
lines = ax.get_lines()
self.assertEqual(xmin, lines[0].get_data()[0][0])
Expand Down
33 changes: 19 additions & 14 deletions pandas/tools/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -927,19 +927,21 @@ def _has_plotted_object(self, ax):

def _maybe_right_yaxis(self, ax, axes_num):
if not self.on_right(axes_num):
if hasattr(ax, 'left_ax'):
# secondary axes may be passed as axes
return ax.left_ax
return ax
# secondary axes may be passed via ax kw
return self._get_ax_layer(ax)

if hasattr(ax, 'right_ax'):
# if it has right_ax proparty, ``ax`` must be left axes
return ax.right_ax
elif hasattr(ax, 'left_ax'):
# if it has left_ax proparty, ``ax`` must be right axes
return ax
else:
# otherwise, create twin axes
orig_ax, new_ax = ax, ax.twinx()
new_ax._get_lines.color_cycle = orig_ax._get_lines.color_cycle

orig_ax.right_ax, new_ax.left_ax = new_ax, orig_ax
new_ax.right_ax = new_ax

if not self._has_plotted_object(orig_ax): # no data on left y
orig_ax.get_yaxis().set_visible(False)
Expand Down Expand Up @@ -987,9 +989,8 @@ def result(self):
all_sec = (com.is_list_like(self.secondary_y) and
len(self.secondary_y) == self.nseries)
if (sec_true or all_sec):
# if all data is plotted on secondary,
# return secondary axes
return self.axes[0].right_ax
# if all data is plotted on secondary, return right axes
return self._get_ax_layer(self.axes[0], primary=False)
else:
return self.axes[0]

Expand Down Expand Up @@ -1229,11 +1230,18 @@ def _get_index_name(self):

return name

@classmethod
def _get_ax_layer(cls, ax, primary=True):
"""get left (primary) or right (secondary) axes"""
if primary:
return getattr(ax, 'left_ax', ax)
else:
return getattr(ax, 'right_ax', ax)

def _get_ax(self, i):
# get the twinx ax if appropriate
if self.subplots:
ax = self.axes[i]

ax = self._maybe_right_yaxis(ax, i)
self.axes[i] = ax
else:
Expand Down Expand Up @@ -2500,8 +2508,7 @@ def plot_series(data, kind='line', ax=None, # Series unique
"""
if ax is None and len(plt.get_fignums()) > 0:
ax = _gca()
ax = getattr(ax, 'left_ax', ax)

ax = MPLPlot._get_ax_layer(ax)
return _plot(data, kind=kind, ax=ax,
figsize=figsize, use_index=use_index, title=title,
grid=grid, legend=legend,
Expand Down Expand Up @@ -3348,11 +3355,9 @@ def _flatten(axes):
def _get_all_lines(ax):
lines = ax.get_lines()

# check for right_ax, which can oddly sometimes point back to ax
if hasattr(ax, 'right_ax') and ax.right_ax != ax:
if hasattr(ax, 'right_ax'):
lines += ax.right_ax.get_lines()

# no such risk with left_ax
if hasattr(ax, 'left_ax'):
lines += ax.left_ax.get_lines()

Expand Down
22 changes: 17 additions & 5 deletions pandas/tseries/tests/test_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -528,7 +528,9 @@ def test_secondary_y(self):

ser = Series(np.random.randn(10))
ser2 = Series(np.random.randn(10))
ax = ser.plot(secondary_y=True).right_ax
ax = ser.plot(secondary_y=True)
self.assertTrue(hasattr(ax, 'left_ax'))
self.assertFalse(hasattr(ax, 'right_ax'))
fig = ax.get_figure()
axes = fig.get_axes()
l = ax.get_lines()[0]
Expand All @@ -543,16 +545,22 @@ def test_secondary_y(self):
plt.close(ax2.get_figure())

ax = ser2.plot()
ax2 = ser.plot(secondary_y=True).right_ax
ax2 = ser.plot(secondary_y=True)
self.assertTrue(ax.get_yaxis().get_visible())
self.assertFalse(hasattr(ax, 'left_ax'))
self.assertTrue(hasattr(ax, 'right_ax'))
self.assertTrue(hasattr(ax2, 'left_ax'))
self.assertFalse(hasattr(ax2, 'right_ax'))

@slow
def test_secondary_y_ts(self):
import matplotlib.pyplot as plt
idx = date_range('1/1/2000', periods=10)
ser = Series(np.random.randn(10), idx)
ser2 = Series(np.random.randn(10), idx)
ax = ser.plot(secondary_y=True).right_ax
ax = ser.plot(secondary_y=True)
self.assertTrue(hasattr(ax, 'left_ax'))
self.assertFalse(hasattr(ax, 'right_ax'))
fig = ax.get_figure()
axes = fig.get_axes()
l = ax.get_lines()[0]
Expand All @@ -577,7 +585,9 @@ def test_secondary_kde(self):

import matplotlib.pyplot as plt
ser = Series(np.random.randn(10))
ax = ser.plot(secondary_y=True, kind='density').right_ax
ax = ser.plot(secondary_y=True, kind='density')
self.assertTrue(hasattr(ax, 'left_ax'))
self.assertFalse(hasattr(ax, 'right_ax'))
fig = ax.get_figure()
axes = fig.get_axes()
self.assertEqual(axes[1].get_yaxis().get_ticks_position(), 'right')
Expand Down Expand Up @@ -922,7 +932,9 @@ def test_secondary_upsample(self):
ax = high.plot(secondary_y=True)
for l in ax.get_lines():
self.assertEqual(PeriodIndex(l.get_xdata()).freq, 'D')
for l in ax.right_ax.get_lines():
self.assertTrue(hasattr(ax, 'left_ax'))
self.assertFalse(hasattr(ax, 'right_ax'))
for l in ax.left_ax.get_lines():
self.assertEqual(PeriodIndex(l.get_xdata()).freq, 'D')

@slow
Expand Down