Skip to content

Commit f5b3f8a

Browse files
committed
Merge pull request #5378 from tacaswell/fix_andrews_colors
BUG: Fixes color selection in andrews_curve
2 parents 3ebd769 + 8e2f3a2 commit f5b3f8a

File tree

2 files changed

+25
-22
lines changed

2 files changed

+25
-22
lines changed

doc/source/release.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -205,6 +205,8 @@ Improvements to existing features
205205
wrapper is updated inplace, a copy is still made internally.
206206
(:issue:`1960`, :issue:`5247`, and related :issue:`2325` [still not
207207
closed])
208+
- Fixed bug in `tools.plotting.andrews_curvres` so that lines are drawn grouped
209+
by color as expected.
208210

209211
API Changes
210212
~~~~~~~~~~~

pandas/tools/plotting.py

Lines changed: 23 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -454,13 +454,14 @@ def f(x):
454454

455455
n = len(data)
456456
class_col = data[class_column]
457+
uniq_class = class_col.drop_duplicates()
457458
columns = [data[col] for col in data.columns if (col != class_column)]
458459
x = [-pi + 2.0 * pi * (t / float(samples)) for t in range(samples)]
459460
used_legends = set([])
460461

461-
colors = _get_standard_colors(num_colors=n, colormap=colormap,
462+
colors = _get_standard_colors(num_colors=len(uniq_class), colormap=colormap,
462463
color_type='random', color=kwds.get('color'))
463-
464+
col_dict = dict([(klass, col) for klass, col in zip(uniq_class, colors)])
464465
if ax is None:
465466
ax = plt.gca(xlim=(-pi, pi))
466467
for i in range(n):
@@ -471,9 +472,9 @@ def f(x):
471472
if com.pprint_thing(class_col[i]) not in used_legends:
472473
label = com.pprint_thing(class_col[i])
473474
used_legends.add(label)
474-
ax.plot(x, y, color=colors[i], label=label, **kwds)
475+
ax.plot(x, y, color=col_dict[class_col[i]], label=label, **kwds)
475476
else:
476-
ax.plot(x, y, color=colors[i], **kwds)
477+
ax.plot(x, y, color=col_dict[class_col[i]], **kwds)
477478

478479
ax.legend(loc='upper right')
479480
ax.grid()
@@ -656,10 +657,10 @@ def lag_plot(series, lag=1, ax=None, **kwds):
656657
ax: Matplotlib axis object
657658
"""
658659
import matplotlib.pyplot as plt
659-
660+
660661
# workaround because `c='b'` is hardcoded in matplotlibs scatter method
661662
kwds.setdefault('c', plt.rcParams['patch.facecolor'])
662-
663+
663664
data = series.values
664665
y1 = data[:-lag]
665666
y2 = data[lag:]
@@ -1212,20 +1213,20 @@ def __init__(self, data, x, y, **kwargs):
12121213
y = self.data.columns[y]
12131214
self.x = x
12141215
self.y = y
1215-
1216-
1216+
1217+
12171218
def _make_plot(self):
12181219
x, y, data = self.x, self.y, self.data
12191220
ax = self.axes[0]
12201221
ax.scatter(data[x].values, data[y].values, **self.kwds)
1221-
1222+
12221223
def _post_plot_logic(self):
12231224
ax = self.axes[0]
1224-
x, y = self.x, self.y
1225+
x, y = self.x, self.y
12251226
ax.set_ylabel(com.pprint_thing(y))
12261227
ax.set_xlabel(com.pprint_thing(x))
1227-
1228-
1228+
1229+
12291230
class LinePlot(MPLPlot):
12301231

12311232
def __init__(self, data, **kwargs):
@@ -1658,25 +1659,25 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
16581659
elif kind == 'kde':
16591660
klass = KdePlot
16601661
elif kind == 'scatter':
1661-
klass = ScatterPlot
1662+
klass = ScatterPlot
16621663
else:
16631664
raise ValueError('Invalid chart type given %s' % kind)
16641665

16651666
if kind == 'scatter':
1666-
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
1667-
rot=rot,legend=legend, ax=ax, style=style,
1667+
plot_obj = klass(frame, x=x, y=y, kind=kind, subplots=subplots,
1668+
rot=rot,legend=legend, ax=ax, style=style,
16681669
fontsize=fontsize, use_index=use_index, sharex=sharex,
1669-
sharey=sharey, xticks=xticks, yticks=yticks,
1670-
xlim=xlim, ylim=ylim, title=title, grid=grid,
1671-
figsize=figsize, logx=logx, logy=logy,
1672-
sort_columns=sort_columns, secondary_y=secondary_y,
1670+
sharey=sharey, xticks=xticks, yticks=yticks,
1671+
xlim=xlim, ylim=ylim, title=title, grid=grid,
1672+
figsize=figsize, logx=logx, logy=logy,
1673+
sort_columns=sort_columns, secondary_y=secondary_y,
16731674
**kwds)
16741675
else:
16751676
if x is not None:
16761677
if com.is_integer(x) and not frame.columns.holds_integer():
16771678
x = frame.columns[x]
16781679
frame = frame.set_index(x)
1679-
1680+
16801681
if y is not None:
16811682
if com.is_integer(y) and not frame.columns.holds_integer():
16821683
y = frame.columns[y]
@@ -1691,7 +1692,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
16911692
grid=grid, logx=logx, logy=logy,
16921693
secondary_y=secondary_y, title=title,
16931694
figsize=figsize, fontsize=fontsize, **kwds)
1694-
1695+
16951696
else:
16961697
plot_obj = klass(frame, kind=kind, subplots=subplots, rot=rot,
16971698
legend=legend, ax=ax, style=style, fontsize=fontsize,
@@ -1700,7 +1701,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
17001701
title=title, grid=grid, figsize=figsize, logx=logx,
17011702
logy=logy, sort_columns=sort_columns,
17021703
secondary_y=secondary_y, **kwds)
1703-
1704+
17041705
plot_obj.generate()
17051706
plot_obj.draw()
17061707
if subplots:

0 commit comments

Comments
 (0)