diff --git a/pandas/tests/test_graphics.py b/pandas/tests/test_graphics.py
index 49dc31514da7a..2d2218c60119f 100644
--- a/pandas/tests/test_graphics.py
+++ b/pandas/tests/test_graphics.py
@@ -419,37 +419,6 @@ def test_explicit_label(self):
         ax = df.plot(x='a', y='b', label='LABEL')
         self.assertEqual(ax.xaxis.get_label().get_text(), 'LABEL')
 
-    @slow
-    def test_plot_xy(self):
-        import matplotlib.pyplot as plt
-        # columns.inferred_type == 'string'
-        df = tm.makeTimeDataFrame()
-        self._check_data(df.plot(x=0, y=1),
-                         df.set_index('A')['B'].plot())
-        self._check_data(df.plot(x=0), df.set_index('A').plot())
-        self._check_data(df.plot(y=0), df.B.plot())
-        self._check_data(df.plot(x='A', y='B'),
-                         df.set_index('A').B.plot())
-        self._check_data(df.plot(x='A'), df.set_index('A').plot())
-        self._check_data(df.plot(y='B'), df.B.plot())
-
-        # columns.inferred_type == 'integer'
-        df.columns = lrange(1, len(df.columns) + 1)
-        self._check_data(df.plot(x=1, y=2),
-                         df.set_index(1)[2].plot())
-        self._check_data(df.plot(x=1), df.set_index(1).plot())
-        self._check_data(df.plot(y=1), df[1].plot())
-
-        # figsize and title
-        ax = df.plot(x=1, y=2, title='Test', figsize=(16, 8))
-
-        self.assertEqual(ax.title.get_text(), 'Test')
-        assert_array_equal(np.round(ax.figure.get_size_inches()),
-                           np.array((16., 8.)))
-
-        # columns.inferred_type == 'mixed'
-        # TODO add MultiIndex test
-
     @slow
     def test_xcompat(self):
         import pandas as pd
@@ -534,6 +503,29 @@ def test_subplots(self):
                 [self.assert_(label.get_visible())
                  for label in ax.get_yticklabels()]
 
+    @slow
+    def test_plot_scatter(self):
+        df = DataFrame(randn(6, 4),
+                       index=list(string.ascii_letters[:6]),
+                       columns=['x', 'y', 'z', 'four'])
+
+        _check_plot_works(df.plot, x='x', y='y', kind='scatter')
+        _check_plot_works(df.plot, x='x', y='y', kind='scatter', legend=False)
+        _check_plot_works(df.plot, x='x', y='y', kind='scatter', subplots=True)
+        _check_plot_works(df.plot, x='x', y='y', kind='scatter', stacked=True)
+
+        # integer column labels
+        df = DataFrame(randn(10, 15),
+                       index=list(string.ascii_letters[:10]),
+                       columns=lrange(15))
+        _check_plot_works(df.plot, x=1, y=2, kind='scatter')
+
+        df = DataFrame({'a': [0, 1], 'b': [1, 0]})
+        _check_plot_works(df.plot, x='a', y='b', kind='scatter')
+
+        # without a y column there is nothing to scatter against x
+        self.assertRaises(ValueError, df.plot, x='a', kind='scatter')
+
     @slow
     def test_plot_bar(self):
         from matplotlib.pylab import close
diff --git a/pandas/tools/plotting.py b/pandas/tools/plotting.py
index 6631a3cf8c6f1..d76534d285917 100644
--- a/pandas/tools/plotting.py
+++ b/pandas/tools/plotting.py
@@ -1193,7 +1193,41 @@ def _post_plot_logic(self):
             for ax in self.axes:
                 ax.legend(loc='best')
 
 
+class ScatterPlot(MPLPlot):
+    """
+    Scatter plot of ``y`` against ``x``.
+
+    Both ``x`` and ``y`` must be supplied as keyword arguments; their
+    values are passed straight to ``Axes.scatter``.
+    """
+
+    def __init__(self, data, **kwargs):
+        # Validate before any base-class setup: having only one of the
+        # two coordinates is as unusable as having neither.
+        if 'x' not in kwargs or 'y' not in kwargs:
+            raise ValueError('Scatterplot requires an x and y column')
+        MPLPlot.__init__(self, data, **kwargs)
+
+    def _make_plot(self):
+        x, y = self.kwds['x'], self.kwds['y']
+
+        # The points do not depend on the column being iterated, so draw
+        # them once per distinct axes instead of once per column.
+        drawn = []
+        for i, _ in enumerate(self._iter_data()):
+            ax = self._get_ax(i)
+            if any(ax is seen for seen in drawn):
+                continue
+            drawn.append(ax)
+            ax.scatter(x, y)
+
+    def _post_plot_logic(self):
+        # Legends are only added here when each series has its own axes.
+        if self.subplots and self.legend:
+            for ax in self.axes:
+                ax.legend(loc='best')
+
+
 class LinePlot(MPLPlot):
 
     def __init__(self, data, **kwargs):
@@ -1554,7 +1588,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
                secondary_y=False, **kwds):
 
     """
-    Make line or bar plot of DataFrame's series with the index on the x-axis
-    using matplotlib / pylab.
+    Make line, bar, or scatter plots of DataFrame series with the index on
+    the x-axis using matplotlib / pylab.
 
     Parameters
@@ -1585,10 +1619,11 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
     ax : matplotlib axis object, default None
     style : list or dict
         matplotlib line style per column
-    kind : {'line', 'bar', 'barh', 'kde', 'density'}
+    kind : {'line', 'bar', 'barh', 'kde', 'density', 'scatter'}
         bar : vertical bar plot
         barh : horizontal bar plot
         kde/density : Kernel Density Estimation plot
+        scatter : scatter plot
     logx : boolean, default False
         For line plots, use log scaling on x axis
     logy : boolean, default False
@@ -1624,6 +1659,8 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
         klass = BarPlot
     elif kind == 'kde':
         klass = KdePlot
+    elif kind == 'scatter':
+        klass = ScatterPlot
     else:
         raise ValueError('Invalid chart type given %s' % kind)
 
@@ -1639,13 +1676,17 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
         label = kwds.pop('label', label)
         ser = frame[y]
         ser.index.name = label
-        return plot_series(ser, label=label, kind=kind,
-                           use_index=use_index,
-                           rot=rot, xticks=xticks, yticks=yticks,
-                           xlim=xlim, ylim=ylim, ax=ax, style=style,
-                           grid=grid, logx=logx, logy=logy,
-                           secondary_y=secondary_y, title=title,
-                           figsize=figsize, fontsize=fontsize, **kwds)
+        if kind != 'scatter':
+            return plot_series(ser, label=label, kind=kind,
+                               use_index=use_index,
+                               rot=rot, xticks=xticks, yticks=yticks,
+                               xlim=xlim, ylim=ylim, ax=ax, style=style,
+                               grid=grid, logx=logx, logy=logy,
+                               secondary_y=secondary_y, title=title,
+                               figsize=figsize, fontsize=fontsize, **kwds)
+        # A scatter plot needs both coordinates, so hand them to the shared
+        # plot-class construction below instead of delegating to plot_series.
+        kwds.update(x=frame.index, y=ser)
 
     plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
                      legend=legend, ax=ax, style=style, fontsize=fontsize,