/ / SPARK, ML, Tuning, CrossValidator: prístup k metrikám - apache-spark, apache-spark-mllib, apache-spark-ml

SPARK, ML, Ladenie, CrossValidator: prístup k metrikám - apache-spark, apache-spark-mllib, apache-spark-ml

S cieľom zostaviť klasifikátor viacerých tried NaiveBayes používam CrossValidator na výber najlepších parametrov v mojom potrubí:

val cv = new CrossValidator()
.setEstimator(pipeline)
.setEstimatorParamMaps(paramGrid)
.setEvaluator(new MulticlassClassificationEvaluator)
.setNumFolds(10)

val cvModel = cv.fit(trainingSet)

Potrubie obsahuje obvyklé transformátory a odhady v tomto poradí: Tokenizer, StopWordsRemover, HashingTF, IDF a nakoniec NaiveBayes.

Je možné získať prístup k metrikám vypočítaným pre najlepší model?

V ideálnom prípade by som chcel získať prístup k metrikám všetkých modelov a zistiť, ako zmena parametrov mení kvalitu klasifikácie. Ale momentálne je najlepší model dosť dobrý.

Pre informáciu, používam Spark 1.6.0

odpovede:

7 pre odpoveď č. 1

Tu je postup, ako to robím:

val pipeline = new Pipeline()
.setStages(Array(tokenizer, stopWordsFilter, tf, idf, word2Vec, featureVectorAssembler, categoryIndexerModel, classifier, categoryReverseIndexer))

...

val paramGrid = new ParamGridBuilder()
.addGrid(tf.numFeatures, Array(10, 100))
.addGrid(idf.minDocFreq, Array(1, 10))
.addGrid(word2Vec.vectorSize, Array(200, 300))
.addGrid(classifier.maxDepth, Array(3, 5))
.build()

paramGrid.size // 16 entries

...

// Print the average metrics per ParamGrid entry
val avgMetricsParamGrid = crossValidatorModel.avgMetrics

// Combine with paramGrid to see how they affect the overall metrics
val combined = paramGrid.zip(avgMetricsParamGrid)

...

val bestModel = crossValidatorModel.bestModel.asInstanceOf[PipelineModel]

// Explain params for each stage
val bestHashingTFNumFeatures = bestModel.stages(2).asInstanceOf[HashingTF].explainParams
val bestIDFMinDocFrequency = bestModel.stages(3).asInstanceOf[IDFModel].explainParams
val bestWord2VecVectorSize = bestModel.stages(4).asInstanceOf[Word2VecModel].explainParams
val bestDecisionTreeDepth = bestModel.stages(7).asInstanceOf[DecisionTreeClassificationModel].explainParams

1 pre odpoveď č. 2
 cvModel.avgMetrics

pracuje v pysparku 2.2.0