@@ -170,7 +170,8 @@ def _to_java_impl(self):
170170 return java_estimator , java_epms , java_evaluator
171171
172172
173- class CrossValidator (Estimator , ValidatorParams , HasParallelism , HasCollectSubModels , MLReadable , MLWritable ):
173+ class CrossValidator (Estimator , ValidatorParams , HasParallelism , HasCollectSubModels ,
174+ MLReadable , MLWritable ):
174175 """
175176
176177 K-fold cross validation performs model selection by splitting the dataset into a set of
@@ -260,7 +261,7 @@ def _fit(self, dataset):
260261 pool = ThreadPool (processes = min (self .getParallelism (), numModels ))
261262 subModels = None
262263 collectSubModelsParam = self .getCollectSubModels ()
263- if ( collectSubModelsParam == True ) :
264+ if collectSubModelsParam :
264265 subModels = [[None for j in range (numModels )] for i in range (nFolds )]
265266
266267 for i in range (nFolds ):
@@ -273,7 +274,7 @@ def _fit(self, dataset):
273274 def singleTrain (paramMapIndex ):
274275 paramMap = epm [paramMapIndex ]
275276 model = est .fit (train , paramMap )
276- if ( collectSubModelsParam == True ) :
277+ if collectSubModelsParam :
277278 subModels [i ][paramMapIndex ] = model
278279 # TODO: duplicate evaluator to take extra params from input
279280 metric = eva .evaluate (model .transform (validation , paramMap ))
@@ -378,7 +379,7 @@ def __init__(self, bestModel, avgMetrics=[], subModels=None):
378379 #: CrossValidator.estimatorParamMaps, in the corresponding order.
379380 self .avgMetrics = avgMetrics
380381 #: sub model list from cross validation
381- self .subModels = subModels
382+ self .subModels = subModels
382383
383384 def _transform (self , dataset ):
384385 return self .bestModel .transform (dataset )
@@ -545,13 +546,13 @@ def _fit(self, dataset):
545546
546547 subModels = None
547548 collectSubModelsParam = self .getCollectSubModels ()
548- if ( collectSubModelsParam == True ) :
549+ if collectSubModelsParam :
549550 subModels = [None for i in range (numModels )]
550551
551552 def singleTrain (paramMapIndex ):
552553 paramMap = epm [paramMapIndex ]
553554 model = est .fit (train , paramMap )
554- if ( collectSubModelsParam ) :
555+ if collectSubModelsParam :
555556 subModels [paramMapIndex ] = model
556557 metric = eva .evaluate (model .transform (validation , paramMap ))
557558 return metric
0 commit comments