@@ -454,13 +454,14 @@ def f(x):
454
454
455
455
n = len (data )
456
456
class_col = data [class_column ]
457
+ uniq_class = class_col .drop_duplicates ()
457
458
columns = [data [col ] for col in data .columns if (col != class_column )]
458
459
x = [- pi + 2.0 * pi * (t / float (samples )) for t in range (samples )]
459
460
used_legends = set ([])
460
461
461
- colors = _get_standard_colors (num_colors = n , colormap = colormap ,
462
+ colors = _get_standard_colors (num_colors = len ( uniq_class ) , colormap = colormap ,
462
463
color_type = 'random' , color = kwds .get ('color' ))
463
-
464
+ col_dict = dict ([( klass , col ) for klass , col in zip ( uniq_class , colors )])
464
465
if ax is None :
465
466
ax = plt .gca (xlim = (- pi , pi ))
466
467
for i in range (n ):
@@ -471,9 +472,9 @@ def f(x):
471
472
if com .pprint_thing (class_col [i ]) not in used_legends :
472
473
label = com .pprint_thing (class_col [i ])
473
474
used_legends .add (label )
474
- ax .plot (x , y , color = colors [ i ], label = label , ** kwds )
475
+ ax .plot (x , y , color = col_dict [ class_col [ i ] ], label = label , ** kwds )
475
476
else :
476
- ax .plot (x , y , color = colors [ i ], ** kwds )
477
+ ax .plot (x , y , color = col_dict [ class_col [ i ] ], ** kwds )
477
478
478
479
ax .legend (loc = 'upper right' )
479
480
ax .grid ()
@@ -656,10 +657,10 @@ def lag_plot(series, lag=1, ax=None, **kwds):
656
657
ax: Matplotlib axis object
657
658
"""
658
659
import matplotlib .pyplot as plt
659
-
660
+
660
661
# workaround because `c='b'` is hardcoded in matplotlibs scatter method
661
662
kwds .setdefault ('c' , plt .rcParams ['patch.facecolor' ])
662
-
663
+
663
664
data = series .values
664
665
y1 = data [:- lag ]
665
666
y2 = data [lag :]
@@ -1212,20 +1213,20 @@ def __init__(self, data, x, y, **kwargs):
1212
1213
y = self .data .columns [y ]
1213
1214
self .x = x
1214
1215
self .y = y
1215
-
1216
-
1216
+
1217
+
1217
1218
def _make_plot (self ):
1218
1219
x , y , data = self .x , self .y , self .data
1219
1220
ax = self .axes [0 ]
1220
1221
ax .scatter (data [x ].values , data [y ].values , ** self .kwds )
1221
-
1222
+
1222
1223
def _post_plot_logic (self ):
1223
1224
ax = self .axes [0 ]
1224
- x , y = self .x , self .y
1225
+ x , y = self .x , self .y
1225
1226
ax .set_ylabel (com .pprint_thing (y ))
1226
1227
ax .set_xlabel (com .pprint_thing (x ))
1227
-
1228
-
1228
+
1229
+
1229
1230
class LinePlot (MPLPlot ):
1230
1231
1231
1232
def __init__ (self , data , ** kwargs ):
@@ -1658,25 +1659,25 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
1658
1659
elif kind == 'kde' :
1659
1660
klass = KdePlot
1660
1661
elif kind == 'scatter' :
1661
- klass = ScatterPlot
1662
+ klass = ScatterPlot
1662
1663
else :
1663
1664
raise ValueError ('Invalid chart type given %s' % kind )
1664
1665
1665
1666
if kind == 'scatter' :
1666
- plot_obj = klass (frame , x = x , y = y , kind = kind , subplots = subplots ,
1667
- rot = rot ,legend = legend , ax = ax , style = style ,
1667
+ plot_obj = klass (frame , x = x , y = y , kind = kind , subplots = subplots ,
1668
+ rot = rot ,legend = legend , ax = ax , style = style ,
1668
1669
fontsize = fontsize , use_index = use_index , sharex = sharex ,
1669
- sharey = sharey , xticks = xticks , yticks = yticks ,
1670
- xlim = xlim , ylim = ylim , title = title , grid = grid ,
1671
- figsize = figsize , logx = logx , logy = logy ,
1672
- sort_columns = sort_columns , secondary_y = secondary_y ,
1670
+ sharey = sharey , xticks = xticks , yticks = yticks ,
1671
+ xlim = xlim , ylim = ylim , title = title , grid = grid ,
1672
+ figsize = figsize , logx = logx , logy = logy ,
1673
+ sort_columns = sort_columns , secondary_y = secondary_y ,
1673
1674
** kwds )
1674
1675
else :
1675
1676
if x is not None :
1676
1677
if com .is_integer (x ) and not frame .columns .holds_integer ():
1677
1678
x = frame .columns [x ]
1678
1679
frame = frame .set_index (x )
1679
-
1680
+
1680
1681
if y is not None :
1681
1682
if com .is_integer (y ) and not frame .columns .holds_integer ():
1682
1683
y = frame .columns [y ]
@@ -1691,7 +1692,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
1691
1692
grid = grid , logx = logx , logy = logy ,
1692
1693
secondary_y = secondary_y , title = title ,
1693
1694
figsize = figsize , fontsize = fontsize , ** kwds )
1694
-
1695
+
1695
1696
else :
1696
1697
plot_obj = klass (frame , kind = kind , subplots = subplots , rot = rot ,
1697
1698
legend = legend , ax = ax , style = style , fontsize = fontsize ,
@@ -1700,7 +1701,7 @@ def plot_frame(frame=None, x=None, y=None, subplots=False, sharex=True,
1700
1701
title = title , grid = grid , figsize = figsize , logx = logx ,
1701
1702
logy = logy , sort_columns = sort_columns ,
1702
1703
secondary_y = secondary_y , ** kwds )
1703
-
1704
+
1704
1705
plot_obj .generate ()
1705
1706
plot_obj .draw ()
1706
1707
if subplots :
0 commit comments