Skip to content

Commit ae082f5

Browse files
committed
fix python style
1 parent 758bc24 commit ae082f5

File tree

1 file changed

+7
-6
lines changed

1 file changed

+7
-6
lines changed

python/pyspark/ml/tuning.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -170,7 +170,8 @@ def _to_java_impl(self):
170170
return java_estimator, java_epms, java_evaluator
171171

172172

173-
class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels, MLReadable, MLWritable):
173+
class CrossValidator(Estimator, ValidatorParams, HasParallelism, HasCollectSubModels,
174+
MLReadable, MLWritable):
174175
"""
175176
176177
K-fold cross validation performs model selection by splitting the dataset into a set of
@@ -260,7 +261,7 @@ def _fit(self, dataset):
260261
pool = ThreadPool(processes=min(self.getParallelism(), numModels))
261262
subModels = None
262263
collectSubModelsParam = self.getCollectSubModels()
263-
if (collectSubModelsParam == True):
264+
if collectSubModelsParam:
264265
subModels = [[None for j in range(numModels)] for i in range(nFolds)]
265266

266267
for i in range(nFolds):
@@ -273,7 +274,7 @@ def _fit(self, dataset):
273274
def singleTrain(paramMapIndex):
274275
paramMap = epm[paramMapIndex]
275276
model = est.fit(train, paramMap)
276-
if (collectSubModelsParam == True):
277+
if collectSubModelsParam:
277278
subModels[i][paramMapIndex] = model
278279
# TODO: duplicate evaluator to take extra params from input
279280
metric = eva.evaluate(model.transform(validation, paramMap))
@@ -378,7 +379,7 @@ def __init__(self, bestModel, avgMetrics=[], subModels=None):
378379
#: CrossValidator.estimatorParamMaps, in the corresponding order.
379380
self.avgMetrics = avgMetrics
380381
#: sub model list from cross validation
381-
self.subModels=subModels
382+
self.subModels = subModels
382383

383384
def _transform(self, dataset):
384385
return self.bestModel.transform(dataset)
@@ -545,13 +546,13 @@ def _fit(self, dataset):
545546

546547
subModels = None
547548
collectSubModelsParam = self.getCollectSubModels()
548-
if (collectSubModelsParam == True):
549+
if collectSubModelsParam:
549550
subModels = [None for i in range(numModels)]
550551

551552
def singleTrain(paramMapIndex):
552553
paramMap = epm[paramMapIndex]
553554
model = est.fit(train, paramMap)
554-
if (collectSubModelsParam):
555+
if collectSubModelsParam:
555556
subModels[paramMapIndex] = model
556557
metric = eva.evaluate(model.transform(validation, paramMap))
557558
return metric

0 commit comments

Comments
 (0)