@@ -53,9 +53,11 @@ class MPLPlot(object):
53
53
data :
54
54
55
55
"""
56
+ _default_rot = 0
57
+
56
58
def __init__ (self , data , kind = None , by = None , subplots = False , sharex = True ,
57
59
sharey = False , use_index = True ,
58
- figsize = None , grid = True , legend = True , rot = 30 ,
60
+ figsize = None , grid = True , legend = True , rot = None ,
59
61
ax = None , fig = None , title = None ,
60
62
xlim = None , ylim = None ,
61
63
xticks = None , yticks = None ,
@@ -97,6 +99,7 @@ def __init__(self, data, kind=None, by=None, subplots=False, sharex=True,
97
99
self ._setup_subplots ()
98
100
self ._make_plot ()
99
101
self ._post_plot_logic ()
102
+ self ._adorn_subplots ()
100
103
101
104
def draw (self ):
102
105
self .plt .draw_if_interactive ()
@@ -235,12 +238,17 @@ def _post_plot_logic(self):
235
238
236
239
237
240
class BarPlot (MPLPlot ):
241
+ _default_rot = {'bar' : 90 , 'barh' : 0 }
238
242
239
243
def __init__ (self , data , ** kwargs ):
240
244
self .stacked = kwargs .pop ('stacked' , False )
245
+ self .ax_pos = np .arange (len (data )) + 0.25
241
246
MPLPlot .__init__ (self , data , ** kwargs )
242
247
243
248
def _args_adjust (self ):
249
+ if self .rot is None :
250
+ self .rot = self ._default_rot [self .kind ]
251
+
244
252
if self .fontsize is None :
245
253
if len (self .data ) < 10 :
246
254
self .fontsize = 12
@@ -264,7 +272,7 @@ def _make_plot(self):
264
272
df = self .data
265
273
266
274
N , K = df .shape
267
- xinds = np . arange ( N ) + 0.25
275
+
268
276
colors = 'rgbyk'
269
277
rects = []
270
278
labels = []
@@ -273,47 +281,62 @@ def _make_plot(self):
273
281
274
282
bar_f = self .bar_f
275
283
276
- prior = np .zeros (N )
284
+ pos_prior = neg_prior = np .zeros (N )
277
285
for i , col in enumerate (df .columns ):
278
286
empty = df [col ].count () == 0
279
287
y = df [col ].values if not empty else np .zeros (len (df ))
280
288
281
289
if self .subplots :
282
290
ax = self .axes [i ]
283
- rect = bar_f (ax , xinds , y , 0.5 , start = prior ,
291
+ rect = bar_f (ax , self . ax_pos , y , 0.5 , start = pos_prior ,
284
292
linewidth = 1 , ** self .kwds )
285
293
ax .set_title (col )
286
294
elif self .stacked :
287
- rect = bar_f (ax , xinds , y , 0.5 , start = prior ,
295
+ mask = y > 0
296
+ start = np .where (mask , pos_prior , neg_prior )
297
+ rect = bar_f (ax , self .ax_pos , y , 0.5 , start = start ,
288
298
color = colors [i % len (colors )],
289
299
label = str (col ), linewidth = 1 ,
290
300
** self .kwds )
291
- prior = y + prior
301
+ pos_prior = pos_prior + np .where (mask , y , 0 )
302
+ neg_prior = neg_prior + np .where (mask , 0 , y )
292
303
else :
293
- rect = bar_f (ax , xinds + i * 0.75 / K , y , 0.75 / K ,
304
+ rect = bar_f (ax , self . ax_pos + i * 0.75 / K , y , 0.75 / K ,
294
305
start = np .zeros (N ), label = str (col ),
295
306
color = colors [i % len (colors )],
296
307
** self .kwds )
297
308
rects .append (rect )
298
309
labels .append (col )
299
310
300
- ax .set_xlim ([xinds [0 ] - 0.25 , xinds [- 1 ] + 1 ])
301
- ax .set_xticks (xinds + 0.375 )
302
- ax .set_xticklabels ([_stringify (key ) for key in df .index ],
303
- rotation = self .rot ,
304
- fontsize = self .fontsize )
305
-
306
311
if self .legend and not self .subplots :
307
312
patches = [r [0 ] for r in rects ]
308
313
309
314
# Legend to the right of the plot
310
315
# ax.legend(patches, labels, bbox_to_anchor=(1.05, 1),
311
316
# loc=2, borderaxespad=0.)
317
+ # self.fig.subplots_adjust(right=0.80)
312
318
313
319
ax .legend (patches , labels , loc = 'best' )
314
320
315
- self .fig .subplots_adjust (top = 0.8 , right = 0.80 )
316
321
322
+ self .fig .subplots_adjust (top = 0.8 )
323
+
324
+ def _post_plot_logic (self ):
325
+ for ax in self .axes :
326
+ str_index = [_stringify (key ) for key in self .data .index ]
327
+ if self .kind == 'bar' :
328
+ ax .set_xlim ([self .ax_pos [0 ] - 0.25 , self .ax_pos [- 1 ] + 1 ])
329
+ ax .set_xticks (self .ax_pos + 0.375 )
330
+ ax .set_xticklabels (str_index , rotation = self .rot ,
331
+ fontsize = self .fontsize )
332
+ ax .axhline (0 , color = 'k' , linestyle = '--' )
333
+ else :
334
+ # horizontal bars
335
+ ax .set_ylim ([self .ax_pos [0 ] - 0.25 , self .ax_pos [- 1 ] + 1 ])
336
+ ax .set_yticks (self .ax_pos + 0.375 )
337
+ ax .set_yticklabels (str_index , rotation = self .rot ,
338
+ fontsize = self .fontsize )
339
+ ax .axvline (0 , color = 'k' , linestyle = '--' )
317
340
318
341
class BoxPlot (MPLPlot ):
319
342
pass
@@ -325,7 +348,7 @@ class HistPlot(MPLPlot):
325
348
326
349
def plot_frame (frame = None , subplots = False , sharex = True , sharey = False ,
327
350
use_index = True ,
328
- figsize = None , grid = True , legend = True , rot = 30 ,
351
+ figsize = None , grid = True , legend = True , rot = None ,
329
352
ax = None , title = None ,
330
353
xlim = None , ylim = None ,
331
354
xticks = None , yticks = None ,
@@ -843,8 +866,9 @@ def _subplots(nrows=1, ncols=1, sharex=False, sharey=False, squeeze=True,
843
866
reload (fr )
844
867
from pandas .core .frame import DataFrame
845
868
846
- data = DataFrame ([[3 , 6 ], [4 , 8 ], [4 , 9 ], [4 , 9 ], [2 , 5 ]],
847
- columns = ['A' , 'B' ])
869
+ data = DataFrame ([[3 , 6 , - 5 ], [4 , 8 , 2 ], [4 , 9 , - 6 ],
870
+ [4 , 9 , - 3 ], [2 , 5 , - 1 ]],
871
+ columns = ['A' , 'B' , 'C' ])
848
872
data .plot (kind = 'barh' , stacked = True )
849
873
850
874
plt .show ()
0 commit comments