@@ -1022,7 +1022,10 @@ def _post_plot_logic(self):
10221022 def _adorn_subplots (self ):
10231023 to_adorn = self .axes
10241024
1025- # todo: sharex, sharey handling?
1025+ if len (self .axes ) > 0 :
1026+ all_axes = self ._get_axes ()
1027+ nrows , ncols = self ._get_axes_layout ()
1028+ _handle_shared_axes (all_axes , len (all_axes ), len (all_axes ), nrows , ncols , self .sharex , self .sharey )
10261029
10271030 for ax in to_adorn :
10281031 if self .yticks is not None :
@@ -1375,6 +1378,19 @@ def _get_errorbars(self, label=None, index=None, xerr=True, yerr=True):
13751378 errors [kw ] = err
13761379 return errors
13771380
1381+ def _get_axes (self ):
1382+ return self .axes [0 ].get_figure ().get_axes ()
1383+
1384+ def _get_axes_layout (self ):
1385+ axes = self ._get_axes ()
1386+ x_set = set ()
1387+ y_set = set ()
1388+ for ax in axes :
1389+ # check axes coordinates to estimate layout
1390+ points = ax .get_position ().get_points ()
1391+ x_set .add (points [0 ][0 ])
1392+ y_set .add (points [0 ][1 ])
1393+ return (len (y_set ), len (x_set ))
13781394
13791395class ScatterPlot (MPLPlot ):
13801396 _layout_type = 'single'
@@ -3231,6 +3247,28 @@ def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True,
32313247 ax = fig .add_subplot (nrows , ncols , i + 1 , ** kwds )
32323248 axarr [i ] = ax
32333249
3250+ _handle_shared_axes (axarr , nplots , naxes , nrows , ncols , sharex , sharey )
3251+
3252+ if naxes != nplots :
3253+ for ax in axarr [naxes :]:
3254+ ax .set_visible (False )
3255+
3256+ if squeeze :
3257+ # Reshape the array to have the final desired dimension (nrow,ncol),
3258+ # though discarding unneeded dimensions that equal 1. If we only have
3259+ # one subplot, just return it instead of a 1-element array.
3260+ if nplots == 1 :
3261+ axes = axarr [0 ]
3262+ else :
3263+ axes = axarr .reshape (nrows , ncols ).squeeze ()
3264+ else :
3265+ # returned axis array will be always 2-d, even if nrows=ncols=1
3266+ axes = axarr .reshape (nrows , ncols )
3267+
3268+ return fig , axes
3269+
3270+
3271+ def _handle_shared_axes (axarr , nplots , naxes , nrows , ncols , sharex , sharey ):
32343272 if nplots > 1 :
32353273
32363274 if sharex and nrows > 1 :
@@ -3241,8 +3279,11 @@ def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True,
32413279 # set_visible will not be effective if
32423280 # minor axis has NullLocator and NullFormattor (default)
32433281 import matplotlib .ticker as ticker
3244- ax .xaxis .set_minor_locator (ticker .AutoLocator ())
3245- ax .xaxis .set_minor_formatter (ticker .FormatStrFormatter ('' ))
3282+
3283+ if isinstance (ax .xaxis .get_minor_locator (), ticker .NullLocator ):
3284+ ax .xaxis .set_minor_locator (ticker .AutoLocator ())
3285+ if isinstance (ax .xaxis .get_minor_formatter (), ticker .NullFormatter ):
3286+ ax .xaxis .set_minor_formatter (ticker .FormatStrFormatter ('' ))
32463287 for label in ax .get_xticklabels (minor = True ):
32473288 label .set_visible (False )
32483289 except Exception : # pragma no cover
@@ -3255,32 +3296,16 @@ def _subplots(naxes=None, sharex=False, sharey=False, squeeze=True,
32553296 label .set_visible (False )
32563297 try :
32573298 import matplotlib .ticker as ticker
3258- ax .yaxis .set_minor_locator (ticker .AutoLocator ())
3259- ax .yaxis .set_minor_formatter (ticker .FormatStrFormatter ('' ))
3299+ if isinstance (ax .yaxis .get_minor_locator (), ticker .NullLocator ):
3300+ ax .yaxis .set_minor_locator (ticker .AutoLocator ())
3301+ if isinstance (ax .yaxis .get_minor_formatter (), ticker .NullFormatter ):
3302+ ax .yaxis .set_minor_formatter (ticker .FormatStrFormatter ('' ))
32603303 for label in ax .get_yticklabels (minor = True ):
32613304 label .set_visible (False )
32623305 except Exception : # pragma no cover
32633306 pass
32643307 ax .yaxis .get_label ().set_visible (False )
32653308
3266- if naxes != nplots :
3267- for ax in axarr [naxes :]:
3268- ax .set_visible (False )
3269-
3270- if squeeze :
3271- # Reshape the array to have the final desired dimension (nrow,ncol),
3272- # though discarding unneeded dimensions that equal 1. If we only have
3273- # one subplot, just return it instead of a 1-element array.
3274- if nplots == 1 :
3275- axes = axarr [0 ]
3276- else :
3277- axes = axarr .reshape (nrows , ncols ).squeeze ()
3278- else :
3279- # returned axis array will be always 2-d, even if nrows=ncols=1
3280- axes = axarr .reshape (nrows , ncols )
3281-
3282- return fig , axes
3283-
32843309
32853310def _flatten (axes ):
32863311 if not com .is_list_like (axes ):
0 commit comments