隨機森林是由多棵決策樹組成的集成模型,透過結合許多決策樹的預測來降低過度擬合的風險。spark.ml 的隨機森林實作同時支援連續特徵與類別特徵,可用於二元分類、多類別分類以及迴歸。
匯入包
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.Dataset import org.apache.spark.sql.Row import org.apache.spark.sql.DataFrame import org.apache.spark.sql.Column import org.apache.spark.sql.DataFrameReader import org.apache.spark.rdd.RDD import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder import org.apache.spark.sql.Encoder import org.apache.spark.sql.DataFrameStatFunctions import org.apache.spark.sql.functions._ import org.apache.spark.ml.linalg.Vectors import org.apache.spark.ml.feature.{ IndexToString, StringIndexer, VectorIndexer } import org.apache.spark.ml.feature.VectorAssembler import org.apache.spark.ml.Pipeline import org.apache.spark.ml.classification.{ RandomForestClassificationModel, RandomForestClassifier } import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator import org.apache.spark.ml.tuning.{ ParamGridBuilder, CrossValidator }
匯入源資料
// affairs:一年來婚外情的頻率 // gender:性別 // age:年齡 // yearsmarried:婚齡 // children:是否有小孩 // religiousness:宗教信仰程度(5分制,1分表示反對,5分表示非常信仰) // education:學歷 // occupation:職業(逆向編號的戈登7種分類) // rating:對婚姻的自我評分(5分制,1表示非常不幸福,5表示非常幸福) val spark = SparkSession.builder().appName("Spark Random Forest Classifier").config("spark.some.config.option", "some-value").getOrCreate() // For implicit conversions like converting RDDs to DataFrames import spark.implicits._ val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List( (0, "male", 37, 10, "no", 3, 18, 7, 4), (0, "female", 27, 4, "no", 4, 14, 6, 4), (0, "female", 32, 15, "yes", 1, 12, 1, 4), (0, "male", 57, 15, "yes", 5, 18, 6, 5), (0, "male", 22, 0.75, "no", 2, 17, 6, 3), (0, "female", 32, 1.5, "no", 2, 17, 5, 5), (0, "female", 22, 0.75, "no", 2, 12, 1, 3), (0, "male", 57, 15, "yes", 2, 14, 4, 4), (0, "female", 32, 15, "yes", 4, 16, 1, 2), (0, "male", 22, 1.5, "no", 4, 14, 4, 5), (0, "male", 37, 15, "yes", 2, 20, 7, 2), (0, "male", 27, 4, "yes", 4, 18, 6, 4), (0, "male", 47, 15, "yes", 5, 17, 6, 4), (0, "female", 22, 1.5, "no", 2, 17, 5, 4), (0, "female", 27, 4, "no", 4, 14, 5, 4), (0, "female", 37, 15, "yes", 1, 17, 5, 5), (0, "female", 37, 15, "yes", 2, 18, 4, 3), (0, "female", 22, 0.75, "no", 3, 16, 5, 4), (0, "female", 22, 1.5, "no", 2, 16, 5, 5), (0, "female", 27, 10, "yes", 2, 14, 1, 5), (0, "female", 22, 1.5, "no", 2, 16, 5, 5), (0, "female", 22, 1.5, "no", 2, 16, 5, 5), (0, "female", 27, 10, "yes", 4, 16, 5, 4), (0, "female", 32, 10, "yes", 3, 14, 1, 5), (0, "male", 37, 4, "yes", 2, 20, 6, 4), (0, "female", 22, 1.5, "no", 2, 18, 5, 5), (0, "female", 27, 7, "no", 4, 16, 1, 5), (0, "male", 42, 15, "yes", 5, 20, 6, 4), (0, "male", 27, 4, "yes", 3, 16, 5, 5), (0, "female", 27, 4, "yes", 3, 17, 5, 4), (0, "male", 42, 15, "yes", 4, 20, 6, 3), (0, "female", 22, 1.5, "no", 3, 16, 5, 5), (0, "male", 27, 0.417, "no", 4, 17, 6, 4), (0, "female", 42, 15, "yes", 5, 14, 5, 4), (0, "male", 32, 4, "yes", 
1, 18, 6, 4), (0, "female", 22, 1.5, "no", 4, 16, 5, 3), (0, "female", 42, 15, "yes", 3, 12, 1, 4), (0, "female", 22, 4, "no", 4, 17, 5, 5), (0, "male", 22, 1.5, "yes", 1, 14, 3, 5), (0, "female", 22, 0.75, "no", 3, 16, 1, 5), (0, "male", 32, 10, "yes", 5, 20, 6, 5), (0, "male", 52, 15, "yes", 5, 18, 6, 3), (0, "female", 22, 0.417, "no", 5, 14, 1, 4), (0, "female", 27, 4, "yes", 2, 18, 6, 1), (0, "female", 32, 7, "yes", 5, 17, 5, 3), (0, "male", 22, 4, "no", 3, 16, 5, 5), (0, "female", 27, 7, "yes", 4, 18, 6, 5), (0, "female", 42, 15, "yes", 2, 18, 5, 4), (0, "male", 27, 1.5, "yes", 4, 16, 3, 5), (0, "male", 42, 15, "yes", 2, 20, 6, 4), (0, "female", 22, 0.75, "no", 5, 14, 3, 5), (0, "male", 32, 7, "yes", 2, 20, 6, 4), (0, "male", 27, 4, "yes", 5, 20, 6, 5), (0, "male", 27, 10, "yes", 4, 20, 6, 4), (0, "male", 22, 4, "no", 1, 18, 5, 5), (0, "female", 37, 15, "yes", 4, 14, 3, 1), (0, "male", 22, 1.5, "yes", 5, 16, 4, 4), (0, "female", 37, 15, "yes", 4, 17, 1, 5), (0, "female", 27, 0.75, "no", 4, 17, 5, 4), (0, "male", 32, 10, "yes", 4, 20, 6, 4), (0, "female", 47, 15, "yes", 5, 14, 7, 2), (0, "male", 37, 10, "yes", 3, 20, 6, 4), (0, "female", 22, 0.75, "no", 2, 16, 5, 5), (0, "male", 27, 4, "no", 2, 18, 4, 5), (0, "male", 32, 7, "no", 4, 20, 6, 4), (0, "male", 42, 15, "yes", 2, 17, 3, 5), (0, "male", 37, 10, "yes", 4, 20, 6, 4), (0, "female", 47, 15, "yes", 3, 17, 6, 5), (0, "female", 22, 1.5, "no", 5, 16, 5, 5), (0, "female", 27, 1.5, "no", 2, 16, 6, 4), (0, "female", 27, 4, "no", 3, 17, 5, 5), (0, "female", 32, 10, "yes", 5, 14, 4, 5), (0, "female", 22, 0.125, "no", 2, 12, 5, 5), (0, "male", 47, 15, "yes", 4, 14, 4, 3), (0, "male", 32, 15, "yes", 1, 14, 5, 5), (0, "male", 27, 7, "yes", 4, 16, 5, 5), (0, "female", 22, 1.5, "yes", 3, 16, 5, 5), (0, "male", 27, 4, "yes", 3, 17, 6, 5), (0, "female", 22, 1.5, "no", 3, 16, 5, 5), (0, "male", 57, 15, "yes", 2, 14, 7, 2), (0, "male", 17.5, 1.5, "yes", 3, 18, 6, 5), (0, "male", 57, 15, "yes", 4, 20, 6, 5), (0, "female", 
22, 0.75, "no", 2, 16, 3, 4), (0, "male", 42, 4, "no", 4, 17, 3, 3), (0, "female", 22, 1.5, "yes", 4, 12, 1, 5), (0, "female", 22, 0.417, "no", 1, 17, 6, 4), (0, "female", 32, 15, "yes", 4, 17, 5, 5), (0, "female", 27, 1.5, "no", 3, 18, 5, 2), (0, "female", 22, 1.5, "yes", 3, 14, 1, 5), (0, "female", 37, 15, "yes", 3, 14, 1, 4), (0, "female", 32, 15, "yes", 4, 14, 3, 4), (0, "male", 37, 10, "yes", 2, 14, 5, 3), (0, "male", 37, 10, "yes", 4, 16, 5, 4), (0, "male", 57, 15, "yes", 5, 20, 5, 3), (0, "male", 27, 0.417, "no", 1, 16, 3, 4), (0, "female", 42, 15, "yes", 5, 14, 1, 5), (0, "male", 57, 15, "yes", 3, 16, 6, 1), (0, "male", 37, 10, "yes", 1, 16, 6, 4), (0, "male", 37, 15, "yes", 3, 17, 5, 5), (0, "male", 37, 15, "yes", 4, 20, 6, 5), (0, "female", 27, 10, "yes", 5, 14, 1, 5), (0, "male", 37, 10, "yes", 2, 18, 6, 4), (0, "female", 22, 0.125, "no", 4, 12, 4, 5), (0, "male", 57, 15, "yes", 5, 20, 6, 5), (0, "female", 37, 15, "yes", 4, 18, 6, 4), (0, "male", 22, 4, "yes", 4, 14, 6, 4), (0, "male", 27, 7, "yes", 4, 18, 5, 4), (0, "male", 57, 15, "yes", 4, 20, 5, 4), (0, "male", 32, 15, "yes", 3, 14, 6, 3), (0, "female", 22, 1.5, "no", 2, 14, 5, 4), (0, "female", 32, 7, "yes", 4, 17, 1, 5), (0, "female", 37, 15, "yes", 4, 17, 6, 5), (0, "female", 32, 1.5, "no", 5, 18, 5, 5), (0, "male", 42, 10, "yes", 5, 20, 7, 4), (0, "female", 27, 7, "no", 3, 16, 5, 4), (0, "male", 37, 15, "no", 4, 20, 6, 5), (0, "male", 37, 15, "yes", 4, 14, 3, 2), (0, "male", 32, 10, "no", 5, 18, 6, 4), (0, "female", 22, 0.75, "no", 4, 16, 1, 5), (0, "female", 27, 7, "yes", 4, 12, 2, 4), (0, "female", 27, 7, "yes", 2, 16, 2, 5), (0, "female", 42, 15, "yes", 5, 18, 5, 4), (0, "male", 42, 15, "yes", 4, 17, 5, 3), (0, "female", 27, 7, "yes", 2, 16, 1, 2), (0, "female", 22, 1.5, "no", 3, 16, 5, 5), (0, "male", 37, 15, "yes", 5, 20, 6, 5), (0, "female", 22, 0.125, "no", 2, 14, 4, 5), (0, "male", 27, 1.5, "no", 4, 16, 5, 5), (0, "male", 32, 1.5, "no", 2, 18, 6, 5), (0, "male", 27, 1.5, "no", 2, 17, 6, 
5), (0, "female", 27, 10, "yes", 4, 16, 1, 3), (0, "male", 42, 15, "yes", 4, 18, 6, 5), (0, "female", 27, 1.5, "no", 2, 16, 6, 5), (0, "male", 27, 4, "no", 2, 18, 6, 3), (0, "female", 32, 10, "yes", 3, 14, 5, 3), (0, "female", 32, 15, "yes", 3, 18, 5, 4), (0, "female", 22, 0.75, "no", 2, 18, 6, 5), (0, "female", 37, 15, "yes", 2, 16, 1, 4), (0, "male", 27, 4, "yes", 4, 20, 5, 5), (0, "male", 27, 4, "no", 1, 20, 5, 4), (0, "female", 27, 10, "yes", 2, 12, 1, 4), (0, "female", 32, 15, "yes", 5, 18, 6, 4), (0, "male", 27, 7, "yes", 5, 12, 5, 3), (0, "male", 52, 15, "yes", 2, 18, 5, 4), (0, "male", 27, 4, "no", 3, 20, 6, 3), (0, "male", 37, 4, "yes", 1, 18, 5, 4), (0, "male", 27, 4, "yes", 4, 14, 5, 4), (0, "female", 52, 15, "yes", 5, 12, 1, 3), (0, "female", 57, 15, "yes", 4, 16, 6, 4), (0, "male", 27, 7, "yes", 1, 16, 5, 4), (0, "male", 37, 7, "yes", 4, 20, 6, 3), (0, "male", 22, 0.75, "no", 2, 14, 4, 3), (0, "male", 32, 4, "yes", 2, 18, 5, 3), (0, "male", 37, 15, "yes", 4, 20, 6, 3), (0, "male", 22, 0.75, "yes", 2, 14, 4, 3), (0, "male", 42, 15, "yes", 4, 20, 6, 3), (0, "female", 52, 15, "yes", 5, 17, 1, 1), (0, "female", 37, 15, "yes", 4, 14, 1, 2), (0, "male", 27, 7, "yes", 4, 14, 5, 3), (0, "male", 32, 4, "yes", 2, 16, 5, 5), (0, "female", 27, 4, "yes", 2, 18, 6, 5), (0, "female", 27, 4, "yes", 2, 18, 5, 5), (0, "male", 37, 15, "yes", 5, 18, 6, 5), (0, "female", 47, 15, "yes", 5, 12, 5, 4), (0, "female", 32, 10, "yes", 3, 17, 1, 4), (0, "female", 27, 1.5, "yes", 4, 17, 1, 2), (0, "female", 57, 15, "yes", 2, 18, 5, 2), (0, "female", 22, 1.5, "no", 4, 14, 5, 4), (0, "male", 42, 15, "yes", 3, 14, 3, 4), (0, "male", 57, 15, "yes", 4, 9, 2, 2), (0, "male", 57, 15, "yes", 4, 20, 6, 5), (0, "female", 22, 0.125, "no", 4, 14, 4, 5), (0, "female", 32, 10, "yes", 4, 14, 1, 5), (0, "female", 42, 15, "yes", 3, 18, 5, 4), (0, "female", 27, 1.5, "no", 2, 18, 6, 5), (0, "male", 32, 0.125, "yes", 2, 18, 5, 2), (0, "female", 27, 4, "no", 3, 16, 5, 4), (0, "female", 27, 10, "yes", 
2, 16, 1, 4), (0, "female", 32, 7, "yes", 4, 16, 1, 3), (0, "female", 37, 15, "yes", 4, 14, 5, 4), (0, "female", 42, 15, "yes", 5, 17, 6, 2), (0, "male", 32, 1.5, "yes", 4, 14, 6, 5), (0, "female", 32, 4, "yes", 3, 17, 5, 3), (0, "female", 37, 7, "no", 4, 18, 5, 5), (0, "female", 22, 0.417, "yes", 3, 14, 3, 5), (0, "female", 27, 7, "yes", 4, 14, 1, 5), (0, "male", 27, 0.75, "no", 3, 16, 5, 5), (0, "male", 27, 4, "yes", 2, 20, 5, 5), (0, "male", 32, 10, "yes", 4, 16, 4, 5), (0, "male", 32, 15, "yes", 1, 14, 5, 5), (0, "male", 22, 0.75, "no", 3, 17, 4, 5), (0, "female", 27, 7, "yes", 4, 17, 1, 4), (0, "male", 27, 0.417, "yes", 4, 20, 5, 4), (0, "male", 37, 15, "yes", 4, 20, 5, 4), (0, "female", 37, 15, "yes", 2, 14, 1, 3), (0, "male", 22, 4, "yes", 1, 18, 5, 4), (0, "male", 37, 15, "yes", 4, 17, 5, 3), (0, "female", 22, 1.5, "no", 2, 14, 4, 5), (0, "male", 52, 15, "yes", 4, 14, 6, 2), (0, "female", 22, 1.5, "no", 4, 17, 5, 5), (0, "male", 32, 4, "yes", 5, 14, 3, 5), (0, "male", 32, 4, "yes", 2, 14, 3, 5), (0, "female", 22, 1.5, "no", 3, 16, 6, 5), (0, "male", 27, 0.75, "no", 2, 18, 3, 3), (0, "female", 22, 7, "yes", 2, 14, 5, 2), (0, "female", 27, 0.75, "no", 2, 17, 5, 3), (0, "female", 37, 15, "yes", 4, 12, 1, 2), (0, "female", 22, 1.5, "no", 1, 14, 1, 5), (0, "female", 37, 10, "no", 2, 12, 4, 4), (0, "female", 37, 15, "yes", 4, 18, 5, 3), (0, "female", 42, 15, "yes", 3, 12, 3, 3), (0, "male", 22, 4, "no", 2, 18, 5, 5), (0, "male", 52, 7, "yes", 2, 20, 6, 2), (0, "male", 27, 0.75, "no", 2, 17, 5, 5), (0, "female", 27, 4, "no", 2, 17, 4, 5), (0, "male", 42, 1.5, "no", 5, 20, 6, 5), (0, "male", 22, 1.5, "no", 4, 17, 6, 5), (0, "male", 22, 4, "no", 4, 17, 5, 3), (0, "female", 22, 4, "yes", 1, 14, 5, 4), (0, "male", 37, 15, "yes", 5, 20, 4, 5), (0, "female", 37, 10, "yes", 3, 16, 6, 3), (0, "male", 42, 15, "yes", 4, 17, 6, 5), (0, "female", 47, 15, "yes", 4, 17, 5, 5), (0, "male", 22, 1.5, "no", 4, 16, 5, 4), (0, "female", 32, 10, "yes", 3, 12, 1, 4), (0, "female", 22, 
7, "yes", 1, 14, 3, 5), (0, "female", 32, 10, "yes", 4, 17, 5, 4), (0, "male", 27, 1.5, "yes", 2, 16, 2, 4), (0, "male", 37, 15, "yes", 4, 14, 5, 5), (0, "male", 42, 4, "yes", 3, 14, 4, 5), (0, "female", 37, 15, "yes", 5, 14, 5, 4), (0, "female", 32, 7, "yes", 4, 17, 5, 5), (0, "female", 42, 15, "yes", 4, 18, 6, 5), (0, "male", 27, 4, "no", 4, 18, 6, 4), (0, "male", 22, 0.75, "no", 4, 18, 6, 5), (0, "male", 27, 4, "yes", 4, 14, 5, 3), (0, "female", 22, 0.75, "no", 5, 18, 1, 5), (0, "female", 52, 15, "yes", 5, 9, 5, 5), (0, "male", 32, 10, "yes", 3, 14, 5, 5), (0, "female", 37, 15, "yes", 4, 16, 4, 4), (0, "male", 32, 7, "yes", 2, 20, 5, 4), (0, "female", 42, 15, "yes", 3, 18, 1, 4), (0, "male", 32, 15, "yes", 1, 16, 5, 5), (0, "male", 27, 4, "yes", 3, 18, 5, 5), (0, "female", 32, 15, "yes", 4, 12, 3, 4), (0, "male", 22, 0.75, "yes", 3, 14, 2, 4), (0, "female", 22, 1.5, "no", 3, 16, 5, 3), (0, "female", 42, 15, "yes", 4, 14, 3, 5), (0, "female", 52, 15, "yes", 3, 16, 5, 4), (0, "male", 37, 15, "yes", 5, 20, 6, 4), (0, "female", 47, 15, "yes", 4, 12, 2, 3), (0, "male", 57, 15, "yes", 2, 20, 6, 4), (0, "male", 32, 7, "yes", 4, 17, 5, 5), (0, "female", 27, 7, "yes", 4, 17, 1, 4), (0, "male", 22, 1.5, "no", 1, 18, 6, 5), (0, "female", 22, 4, "yes", 3, 9, 1, 4), (0, "female", 22, 1.5, "no", 2, 14, 1, 5), (0, "male", 42, 15, "yes", 2, 20, 6, 4), (0, "male", 57, 15, "yes", 4, 9, 2, 4), (0, "female", 27, 7, "yes", 2, 18, 1, 5), (0, "female", 22, 4, "yes", 3, 14, 1, 5), (0, "male", 37, 15, "yes", 4, 14, 5, 3), (0, "male", 32, 7, "yes", 1, 18, 6, 4), (0, "female", 22, 1.5, "no", 2, 14, 5, 5), (0, "female", 22, 1.5, "yes", 3, 12, 1, 3), (0, "male", 52, 15, "yes", 2, 14, 5, 5), (0, "female", 37, 15, "yes", 2, 14, 1, 1), (0, "female", 32, 10, "yes", 2, 14, 5, 5), (0, "male", 42, 15, "yes", 4, 20, 4, 5), (0, "female", 27, 4, "yes", 3, 18, 4, 5), (0, "male", 37, 15, "yes", 4, 20, 6, 5), (0, "male", 27, 1.5, "no", 3, 18, 5, 5), (0, "female", 22, 0.125, "no", 2, 16, 6, 3), (0, 
"male", 32, 10, "yes", 2, 20, 6, 3), (0, "female", 27, 4, "no", 4, 18, 5, 4), (0, "female", 27, 7, "yes", 2, 12, 5, 1), (0, "male", 32, 4, "yes", 5, 18, 6, 3), (0, "female", 37, 15, "yes", 2, 17, 5, 5), (0, "male", 47, 15, "no", 4, 20, 6, 4), (0, "male", 27, 1.5, "no", 1, 18, 5, 5), (0, "male", 37, 15, "yes", 4, 20, 6, 4), (0, "female", 32, 15, "yes", 4, 18, 1, 4), (0, "female", 32, 7, "yes", 4, 17, 5, 4), (0, "female", 42, 15, "yes", 3, 14, 1, 3), (0, "female", 27, 7, "yes", 3, 16, 1, 4), (0, "male", 27, 1.5, "no", 3, 16, 4, 2), (0, "male", 22, 1.5, "no", 3, 16, 3, 5), (0, "male", 27, 4, "yes", 3, 16, 4, 2), (0, "female", 27, 7, "yes", 3, 12, 1, 2), (0, "female", 37, 15, "yes", 2, 18, 5, 4), (0, "female", 37, 7, "yes", 3, 14, 4, 4), (0, "male", 22, 1.5, "no", 2, 16, 5, 5), (0, "male", 37, 15, "yes", 5, 20, 5, 4), (0, "female", 22, 1.5, "no", 4, 16, 5, 3), (0, "female", 32, 10, "yes", 4, 16, 1, 5), (0, "male", 27, 4, "no", 2, 17, 5, 3), (0, "female", 22, 0.417, "no", 4, 14, 5, 5), (0, "female", 27, 4, "no", 2, 18, 5, 5), (0, "male", 37, 15, "yes", 4, 18, 5, 3), (0, "male", 37, 10, "yes", 5, 20, 7, 4), (0, "female", 27, 7, "yes", 2, 14, 4, 2), (0, "male", 32, 4, "yes", 2, 16, 5, 5), (0, "male", 32, 4, "yes", 2, 16, 6, 4), (0, "male", 22, 1.5, "no", 3, 18, 4, 5), (0, "female", 22, 4, "yes", 4, 14, 3, 4), (0, "female", 17.5, 0.75, "no", 2, 18, 5, 4), (0, "male", 32, 10, "yes", 4, 20, 4, 5), (0, "female", 32, 0.75, "no", 5, 14, 3, 3), (0, "male", 37, 15, "yes", 4, 17, 5, 3), (0, "male", 32, 4, "no", 3, 14, 4, 5), (0, "female", 27, 1.5, "no", 2, 17, 3, 2), (0, "female", 22, 7, "yes", 4, 14, 1, 5), (0, "male", 47, 15, "yes", 5, 14, 6, 5), (0, "male", 27, 4, "yes", 1, 16, 4, 4), (0, "female", 37, 15, "yes", 5, 14, 1, 3), (0, "male", 42, 4, "yes", 4, 18, 5, 5), (0, "female", 32, 4, "yes", 2, 14, 1, 5), (0, "male", 52, 15, "yes", 2, 14, 7, 4), (0, "female", 22, 1.5, "no", 2, 16, 1, 4), (0, "male", 52, 15, "yes", 4, 12, 2, 4), (0, "female", 22, 0.417, "no", 3, 17, 1, 5), (0, 
"female", 22, 1.5, "no", 2, 16, 5, 5), (0, "male", 27, 4, "yes", 4, 20, 6, 4), (0, "female", 32, 15, "yes", 4, 14, 1, 5), (0, "female", 27, 1.5, "no", 2, 16, 3, 5), (0, "male", 32, 4, "no", 1, 20, 6, 5), (0, "male", 37, 15, "yes", 3, 20, 6, 4), (0, "female", 32, 10, "no", 2, 16, 6, 5), (0, "female", 32, 10, "yes", 5, 14, 5, 5), (0, "male", 37, 1.5, "yes", 4, 18, 5, 3), (0, "male", 32, 1.5, "no", 2, 18, 4, 4), (0, "female", 32, 10, "yes", 4, 14, 1, 4), (0, "female", 47, 15, "yes", 4, 18, 5, 4), (0, "female", 27, 10, "yes", 5, 12, 1, 5), (0, "male", 27, 4, "yes", 3, 16, 4, 5), (0, "female", 37, 15, "yes", 4, 12, 4, 2), (0, "female", 27, 0.75, "no", 4, 16, 5, 5), (0, "female", 37, 15, "yes", 4, 16, 1, 5), (0, "female", 32, 15, "yes", 3, 16, 1, 5), (0, "female", 27, 10, "yes", 2, 16, 1, 5), (0, "male", 27, 7, "no", 2, 20, 6, 5), (0, "female", 37, 15, "yes", 2, 14, 1, 3), (0, "male", 27, 1.5, "yes", 2, 17, 4, 4), (0, "female", 22, 0.75, "yes", 2, 14, 1, 5), (0, "male", 22, 4, "yes", 4, 14, 2, 4), (0, "male", 42, 0.125, "no", 4, 17, 6, 4), (0, "male", 27, 1.5, "yes", 4, 18, 6, 5), (0, "male", 27, 7, "yes", 3, 16, 6, 3), (0, "female", 52, 15, "yes", 4, 14, 1, 3), (0, "male", 27, 1.5, "no", 5, 20, 5, 2), (0, "female", 27, 1.5, "no", 2, 16, 5, 5), (0, "female", 27, 1.5, "no", 3, 17, 5, 5), (0, "male", 22, 0.125, "no", 5, 16, 4, 4), (0, "female", 27, 4, "yes", 4, 16, 1, 5), (0, "female", 27, 4, "yes", 4, 12, 1, 5), (0, "female", 47, 15, "yes", 2, 14, 5, 5), (0, "female", 32, 15, "yes", 3, 14, 5, 3), (0, "male", 42, 7, "yes", 2, 16, 5, 5), (0, "male", 22, 0.75, "no", 4, 16, 6, 4), (0, "male", 27, 0.125, "no", 3, 20, 6, 5), (0, "male", 32, 10, "yes", 3, 20, 6, 5), (0, "female", 22, 0.417, "no", 5, 14, 4, 5), (0, "female", 47, 15, "yes", 5, 14, 1, 4), (0, "female", 32, 10, "yes", 3, 14, 1, 5), (0, "male", 57, 15, "yes", 4, 17, 5, 5), (0, "male", 27, 4, "yes", 3, 20, 6, 5), (0, "female", 32, 7, "yes", 4, 17, 1, 5), (0, "female", 37, 10, "yes", 4, 16, 1, 5), (0, "female", 32, 10, 
"yes", 1, 18, 1, 4), (0, "female", 22, 4, "no", 3, 14, 1, 4), (0, "female", 27, 7, "yes", 4, 14, 3, 2), (0, "male", 57, 15, "yes", 5, 18, 5, 2), (0, "male", 32, 7, "yes", 2, 18, 5, 5), (0, "female", 27, 1.5, "no", 4, 17, 1, 3), (0, "male", 22, 1.5, "no", 4, 14, 5, 5), (0, "female", 22, 1.5, "yes", 4, 14, 5, 4), (0, "female", 32, 7, "yes", 3, 16, 1, 5), (0, "female", 47, 15, "yes", 3, 16, 5, 4), (0, "female", 22, 0.75, "no", 3, 16, 1, 5), (0, "female", 22, 1.5, "yes", 2, 14, 5, 5), (0, "female", 27, 4, "yes", 1, 16, 5, 5), (0, "male", 52, 15, "yes", 4, 16, 5, 5), (0, "male", 32, 10, "yes", 4, 20, 6, 5), (0, "male", 47, 15, "yes", 4, 16, 6, 4), (0, "female", 27, 7, "yes", 2, 14, 1, 2), (0, "female", 22, 1.5, "no", 4, 14, 4, 5), (0, "female", 32, 10, "yes", 2, 16, 5, 4), (0, "female", 22, 0.75, "no", 2, 16, 5, 4), (0, "female", 22, 1.5, "no", 2, 16, 5, 5), (0, "female", 42, 15, "yes", 3, 18, 6, 4), (0, "female", 27, 7, "yes", 5, 14, 4, 5), (0, "male", 42, 15, "yes", 4, 16, 4, 4), (0, "female", 57, 15, "yes", 3, 18, 5, 2), (0, "male", 42, 15, "yes", 3, 18, 6, 2), (0, "female", 32, 7, "yes", 2, 14, 1, 2), (0, "male", 22, 4, "no", 5, 12, 4, 5), (0, "female", 22, 1.5, "no", 1, 16, 6, 5), (0, "female", 22, 0.75, "no", 1, 14, 4, 5), (0, "female", 32, 15, "yes", 4, 12, 1, 5), (0, "male", 22, 1.5, "no", 2, 18, 5, 3), (0, "male", 27, 4, "yes", 5, 17, 2, 5), (0, "female", 27, 4, "yes", 4, 12, 1, 5), (0, "male", 42, 15, "yes", 5, 18, 5, 4), (0, "male", 32, 1.5, "no", 2, 20, 7, 3), (0, "male", 57, 15, "no", 4, 9, 3, 1), (0, "male", 37, 7, "no", 4, 18, 5, 5), (0, "male", 52, 15, "yes", 2, 17, 5, 4), (0, "male", 47, 15, "yes", 4, 17, 6, 5), (0, "female", 27, 7, "no", 2, 17, 5, 4), (0, "female", 27, 7, "yes", 4, 14, 5, 5), (0, "female", 22, 4, "no", 2, 14, 3, 3), (0, "male", 37, 7, "yes", 2, 20, 6, 5), (0, "male", 27, 7, "no", 4, 12, 4, 3), (0, "male", 42, 10, "yes", 4, 18, 6, 4), (0, "female", 22, 1.5, "no", 3, 14, 1, 5), (0, "female", 22, 4, "yes", 2, 14, 1, 3), (0, "female", 57, 
15, "no", 4, 20, 6, 5), (0, "male", 37, 15, "yes", 4, 14, 4, 3), (0, "female", 27, 7, "yes", 3, 18, 5, 5), (0, "female", 17.5, 10, "no", 4, 14, 4, 5), (0, "male", 22, 4, "yes", 4, 16, 5, 5), (0, "female", 27, 4, "yes", 2, 16, 1, 4), (0, "female", 37, 15, "yes", 2, 14, 5, 1), (0, "female", 22, 1.5, "no", 5, 14, 1, 4), (0, "male", 27, 7, "yes", 2, 20, 5, 4), (0, "male", 27, 4, "yes", 4, 14, 5, 5), (0, "male", 22, 0.125, "no", 1, 16, 3, 5), (0, "female", 27, 7, "yes", 4, 14, 1, 4), (0, "female", 32, 15, "yes", 5, 16, 5, 3), (0, "male", 32, 10, "yes", 4, 18, 5, 4), (0, "female", 32, 15, "yes", 2, 14, 3, 4), (0, "female", 22, 1.5, "no", 3, 17, 5, 5), (0, "male", 27, 4, "yes", 4, 17, 4, 4), (0, "female", 52, 15, "yes", 5, 14, 1, 5), (0, "female", 27, 7, "yes", 2, 12, 1, 2), (0, "female", 27, 7, "yes", 3, 12, 1, 4), (0, "female", 42, 15, "yes", 2, 14, 1, 4), (0, "female", 42, 15, "yes", 4, 14, 5, 4), (0, "male", 27, 7, "yes", 4, 14, 3, 3), (0, "male", 27, 7, "yes", 2, 20, 6, 2), (0, "female", 42, 15, "yes", 3, 12, 3, 3), (0, "male", 27, 4, "yes", 3, 16, 3, 5), (0, "female", 27, 7, "yes", 3, 14, 1, 4), (0, "female", 22, 1.5, "no", 2, 14, 4, 5), (0, "female", 27, 4, "yes", 4, 14, 1, 4), (0, "female", 22, 4, "no", 4, 14, 5, 5), (0, "female", 22, 1.5, "no", 2, 16, 4, 5), (0, "male", 47, 15, "no", 4, 14, 5, 4), (0, "male", 37, 10, "yes", 2, 18, 6, 2), (0, "male", 37, 15, "yes", 3, 17, 5, 4), (0, "female", 27, 4, "yes", 2, 16, 1, 4), (3, "male", 27, 1.5, "no", 3, 18, 4, 4), (3, "female", 27, 4, "yes", 3, 17, 1, 5), (7, "male", 37, 15, "yes", 5, 18, 6, 2), (12, "female", 32, 10, "yes", 3, 17, 5, 2), (1, "male", 22, 0.125, "no", 4, 16, 5, 5), (1, "female", 22, 1.5, "yes", 2, 14, 1, 5), (12, "male", 37, 15, "yes", 4, 14, 5, 2), (7, "female", 22, 1.5, "no", 2, 14, 3, 4), (2, "male", 37, 15, "yes", 2, 18, 6, 4), (3, "female", 32, 15, "yes", 4, 12, 3, 2), (1, "female", 37, 15, "yes", 4, 14, 4, 2), (7, "female", 42, 15, "yes", 3, 17, 1, 4), (12, "female", 42, 15, "yes", 5, 9, 4, 1), 
(12, "male", 37, 10, "yes", 2, 20, 6, 2), (12, "female", 32, 15, "yes", 3, 14, 1, 2), (3, "male", 27, 4, "no", 1, 18, 6, 5), (7, "male", 37, 10, "yes", 2, 18, 7, 3), (7, "female", 27, 4, "no", 3, 17, 5, 5), (1, "male", 42, 15, "yes", 4, 16, 5, 5), (1, "female", 47, 15, "yes", 5, 14, 4, 5), (7, "female", 27, 4, "yes", 3, 18, 5, 4), (1, "female", 27, 7, "yes", 5, 14, 1, 4), (12, "male", 27, 1.5, "yes", 3, 17, 5, 4), (12, "female", 27, 7, "yes", 4, 14, 6, 2), (3, "female", 42, 15, "yes", 4, 16, 5, 4), (7, "female", 27, 10, "yes", 4, 12, 7, 3), (1, "male", 27, 1.5, "no", 2, 18, 5, 2), (1, "male", 32, 4, "no", 4, 20, 6, 4), (1, "female", 27, 7, "yes", 3, 14, 1, 3), (3, "female", 32, 10, "yes", 4, 14, 1, 4), (3, "male", 27, 4, "yes", 2, 18, 7, 2), (1, "female", 17.5, 0.75, "no", 5, 14, 4, 5), (1, "female", 32, 10, "yes", 4, 18, 1, 5), (7, "female", 32, 7, "yes", 2, 17, 6, 4), (7, "male", 37, 15, "yes", 2, 20, 6, 4), (7, "female", 37, 10, "no", 1, 20, 5, 3), (12, "female", 32, 10, "yes", 2, 16, 5, 5), (7, "male", 52, 15, "yes", 2, 20, 6, 4), (7, "female", 42, 15, "yes", 1, 12, 1, 3), (1, "male", 52, 15, "yes", 2, 20, 6, 3), (2, "male", 37, 15, "yes", 3, 18, 6, 5), (12, "female", 22, 4, "no", 3, 12, 3, 4), (12, "male", 27, 7, "yes", 1, 18, 6, 2), (1, "male", 27, 4, "yes", 3, 18, 5, 5), (12, "male", 47, 15, "yes", 4, 17, 6, 5), (12, "female", 42, 15, "yes", 4, 12, 1, 1), (7, "male", 27, 4, "no", 3, 14, 3, 4), (7, "female", 32, 7, "yes", 4, 18, 4, 5), (1, "male", 32, 0.417, "yes", 3, 12, 3, 4), (3, "male", 47, 15, "yes", 5, 16, 5, 4), (12, "male", 37, 15, "yes", 2, 20, 5, 4), (7, "male", 22, 4, "yes", 2, 17, 6, 4), (1, "male", 27, 4, "no", 2, 14, 4, 5), (7, "female", 52, 15, "yes", 5, 16, 1, 3), (1, "male", 27, 4, "no", 3, 14, 3, 3), (1, "female", 27, 10, "yes", 4, 16, 1, 4), (1, "male", 32, 7, "yes", 3, 14, 7, 4), (7, "male", 32, 7, "yes", 2, 18, 4, 1), (3, "male", 22, 1.5, "no", 1, 14, 3, 2), (7, "male", 22, 4, "yes", 3, 18, 6, 4), (7, "male", 42, 15, "yes", 4, 20, 6, 4), 
(2, "female", 57, 15, "yes", 1, 18, 5, 4), (7, "female", 32, 4, "yes", 3, 18, 5, 2), (1, "male", 27, 4, "yes", 1, 16, 4, 4), (7, "male", 32, 7, "yes", 4, 16, 1, 4), (2, "male", 57, 15, "yes", 1, 17, 4, 4), (7, "female", 42, 15, "yes", 4, 14, 5, 2), (7, "male", 37, 10, "yes", 1, 18, 5, 3), (3, "male", 42, 15, "yes", 3, 17, 6, 1), (1, "female", 52, 15, "yes", 3, 14, 4, 4), (2, "female", 27, 7, "yes", 3, 17, 5, 3), (12, "male", 32, 7, "yes", 2, 12, 4, 2), (1, "male", 22, 4, "no", 4, 14, 2, 5), (3, "male", 27, 7, "yes", 3, 18, 6, 4), (12, "female", 37, 15, "yes", 1, 18, 5, 5), (7, "female", 32, 15, "yes", 3, 17, 1, 3), (7, "female", 27, 7, "no", 2, 17, 5, 5), (1, "female", 32, 7, "yes", 3, 17, 5, 3), (1, "male", 32, 1.5, "yes", 2, 14, 2, 4), (12, "female", 42, 15, "yes", 4, 14, 1, 2), (7, "male", 32, 10, "yes", 3, 14, 5, 4), (7, "male", 37, 4, "yes", 1, 20, 6, 3), (1, "female", 27, 4, "yes", 2, 16, 5, 3), (12, "female", 42, 15, "yes", 3, 14, 4, 3), (1, "male", 27, 10, "yes", 5, 20, 6, 5), (12, "male", 37, 10, "yes", 2, 20, 6, 2), (12, "female", 27, 7, "yes", 1, 14, 3, 3), (3, "female", 27, 7, "yes", 4, 12, 1, 2), (3, "male", 32, 10, "yes", 2, 14, 4, 4), (12, "female", 17.5, 0.75, "yes", 2, 12, 1, 3), (12, "female", 32, 15, "yes", 3, 18, 5, 4), (2, "female", 22, 7, "no", 4, 14, 4, 3), (1, "male", 32, 7, "yes", 4, 20, 6, 5), (7, "male", 27, 4, "yes", 2, 18, 6, 2), (1, "female", 22, 1.5, "yes", 5, 14, 5, 3), (12, "female", 32, 15, "no", 3, 17, 5, 1), (12, "female", 42, 15, "yes", 2, 12, 1, 2), (7, "male", 42, 15, "yes", 3, 20, 5, 4), (12, "male", 32, 10, "no", 2, 18, 4, 2), (12, "female", 32, 15, "yes", 3, 9, 1, 1), (7, "male", 57, 15, "yes", 5, 20, 4, 5), (12, "male", 47, 15, "yes", 4, 20, 6, 4), (2, "female", 42, 15, "yes", 2, 17, 6, 3), (12, "male", 37, 15, "yes", 3, 17, 6, 3), (12, "male", 37, 15, "yes", 5, 17, 5, 2), (7, "male", 27, 10, "yes", 2, 20, 6, 4), (2, "male", 37, 15, "yes", 2, 16, 5, 4), (12, "female", 32, 15, "yes", 1, 14, 5, 2), (7, "male", 32, 10, "yes", 
3, 17, 6, 3), (2, "male", 37, 15, "yes", 4, 18, 5, 1), (7, "female", 27, 1.5, "no", 2, 17, 5, 5), (3, "female", 47, 15, "yes", 2, 17, 5, 2), (12, "male", 37, 15, "yes", 2, 17, 5, 4), (12, "female", 27, 4, "no", 2, 14, 5, 5), (2, "female", 27, 10, "yes", 4, 14, 1, 5), (1, "female", 22, 4, "yes", 3, 16, 1, 3), (12, "male", 52, 7, "no", 4, 16, 5, 5), (2, "female", 27, 4, "yes", 1, 16, 3, 5), (7, "female", 37, 15, "yes", 2, 17, 6, 4), (2, "female", 27, 4, "no", 1, 17, 3, 1), (12, "female", 17.5, 0.75, "yes", 2, 12, 3, 5), (7, "female", 32, 15, "yes", 5, 18, 5, 4), (7, "female", 22, 4, "no", 1, 16, 3, 5), (2, "male", 32, 4, "yes", 4, 18, 6, 4), (1, "female", 22, 1.5, "yes", 3, 18, 5, 2), (3, "female", 42, 15, "yes", 2, 17, 5, 4), (1, "male", 32, 7, "yes", 4, 16, 4, 4), (12, "male", 37, 15, "no", 3, 14, 6, 2), (1, "male", 42, 15, "yes", 3, 16, 6, 3), (1, "male", 27, 4, "yes", 1, 18, 5, 4), (2, "male", 37, 15, "yes", 4, 20, 7, 3), (7, "male", 37, 15, "yes", 3, 20, 6, 4), (3, "male", 22, 1.5, "no", 2, 12, 3, 3), (3, "male", 32, 4, "yes", 3, 20, 6, 2), (2, "male", 32, 15, "yes", 5, 20, 6, 5), (12, "female", 52, 15, "yes", 1, 18, 5, 5), (12, "male", 47, 15, "no", 1, 18, 6, 5), (3, "female", 32, 15, "yes", 4, 16, 4, 4), (7, "female", 32, 15, "yes", 3, 14, 3, 2), (7, "female", 27, 7, "yes", 4, 16, 1, 2), (12, "male", 42, 15, "yes", 3, 18, 6, 2), (7, "female", 42, 15, "yes", 2, 14, 3, 2), (12, "male", 27, 7, "yes", 2, 17, 5, 4), (3, "male", 32, 10, "yes", 4, 14, 4, 3), (7, "male", 47, 15, "yes", 3, 16, 4, 2), (1, "male", 22, 1.5, "yes", 1, 12, 2, 5), (7, "female", 32, 10, "yes", 2, 18, 5, 4), (2, "male", 32, 10, "yes", 2, 17, 6, 5), (2, "male", 22, 7, "yes", 3, 18, 6, 2), (1, "female", 32, 15, "yes", 3, 14, 1, 5)) val data = dataList.toDF("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")
隨機森林建模
// ---------------------------------------------------------------------------
// Random-forest classification walkthrough.
//
// Reads the `data` DataFrame, recodes the string columns and the target into
// doubles via SQL, assembles a feature vector, trains a 10-tree random forest
// inside a Pipeline on a 70/30 random split, then prints the test error and
// the learned trees.
// ---------------------------------------------------------------------------

// Expose the raw DataFrame to Spark SQL under the name "data".
data.createOrReplaceTempView("data")

// SQL fragments that turn the target and the two string columns into numbers:
//   label    -> 0 when affairs = 0, otherwise 1
//   gender   -> 0 for 'female', otherwise 1
//   children -> 0 for 'no', otherwise 1
val labelWhere = "case when affairs=0 then 0 else cast(1 as double) end as label"
val genderWhere = "case when gender='female' then 0 else cast(1 as double) end as gender"
val childrenWhere = "case when children='no' then 0 else cast(1 as double) end as children"

val dataLabelDF = spark.sql(
  s"select $labelWhere, $genderWhere,age,yearsmarried,$childrenWhere,religiousness,education,occupation,rating from data"
)

// Columns that are packed, in this order, into the "features" vector.
val featuresArray = Array(
  "gender",
  "age",
  "yearsmarried",
  "children",
  "religiousness",
  "education",
  "occupation",
  "rating"
)

// Assemble the individual columns into a single feature vector column.
val assembler = (new VectorAssembler()
  .setInputCols(featuresArray)
  .setOutputCol("features"))
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF.show(10, truncate = false)

// Random 70% / 30% split into training and test sets.
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.7, 0.3))

// Index the label column; fitted on the whole of vecDF.
val labelIndexer = (new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(vecDF))
//labelIndexer.transform(vecDF).show(10, truncate = false)

// Detect categorical features automatically and index them; with
// maxCategories = 5, a feature with more than 5 distinct values is treated
// as continuous. Also fitted on the whole of vecDF.
val featureIndexer = (new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(5)
  .fit(vecDF))
//featureIndexer.transform(vecDF).show(10, truncate = false)

// The random-forest estimator: 10 trees over the indexed label/features.
val rf = (new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setNumTrees(10))

// Map the predicted label index back to the original label value.
val labelConverter = (new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels))

// Chain the indexers, the forest and the label converter. The forest sits at
// stage index 2, which the stages(2) look-ups below rely on.
val pipeline = new Pipeline().setStages(
  Array(labelIndexer, featureIndexer, rf, labelConverter)
)

// Fit the whole pipeline on the training split.
val model = pipeline.fit(trainingDF)

// Show every parameter value of the fitted random-forest stage.
model.stages(2).extractParamMap()

// Score the held-out test split and display a few example rows.
val predictions = model.transform(testDF)
predictions.select("predictedLabel", "label", "features").show(10, truncate = false)

// Compare the predicted index with the indexed label and report the error
// as 1 - accuracy.
val evaluator = (new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy"))
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${1.0 - accuracy}")

// Stage 2 is the forest added to the pipeline above; cast it to its concrete
// model type to dump the learned trees.
val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel]
println(s"Learned classification forest model:\n${rfModel.toDebugString}")
程式碼執行結果
vecDF.show(10, truncate = false) +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ |label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ |0.0 |1.0 |37.0|10.0 |0.0 |3.0 |18.0 |7.0 |4.0 |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]| |0.0 |0.0 |27.0|4.0 |0.0 |4.0 |14.0 |6.0 |4.0 |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] | |0.0 |0.0 |32.0|15.0 |1.0 |1.0 |12.0 |1.0 |4.0 |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]| |0.0 |1.0 |57.0|15.0 |1.0 |5.0 |18.0 |6.0 |5.0 |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]| |0.0 |1.0 |22.0|0.75 |0.0 |2.0 |17.0 |6.0 |3.0 |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]| |0.0 |0.0 |32.0|1.5 |0.0 |2.0 |17.0 |5.0 |5.0 |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] | |0.0 |0.0 |22.0|0.75 |0.0 |2.0 |12.0 |1.0 |3.0 |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]| |0.0 |1.0 |57.0|15.0 |1.0 |2.0 |14.0 |4.0 |4.0 |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]| |0.0 |0.0 |32.0|15.0 |1.0 |4.0 |16.0 |1.0 |2.0 |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]| |0.0 |1.0 |22.0|1.5 |0.0 |4.0 |14.0 |4.0 |5.0 |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] | +-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+ only showing top 10 rows // 將資料分為訓練和測試集(30%進行測試) val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.7, 0.3)) trainingDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 8 more fields] testDF: org.apache.spark.sql.Dataset[org.apache.spark.sql.Row] = [label: double, gender: double ... 
8 more fields] // 索引標籤,將後設資料新增到標籤列中 val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF) labelIndexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_37df210602df //labelIndexer.transform(vecDF).show(10, truncate = false) // 自動識別分類的特徵,並對它們進行索引 // 具有大於5個不同的值的特徵被視為連續。 val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(5).fit(vecDF) featureIndexer: org.apache.spark.ml.feature.VectorIndexerModel = vecIdx_9595c228f520 //featureIndexer.transform(vecDF).show(10, truncate = false) // 訓練隨機森林模型 val rf = new RandomForestClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setNumTrees(10) rf: org.apache.spark.ml.classification.RandomForestClassifier = rfc_d0e7623d0b10 // 將索引標籤轉換回原始標籤 val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels) labelConverter: org.apache.spark.ml.feature.IndexToString = idxToStr_32d6938f2c94 // Chain indexers and forest in a Pipeline. val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, rf, labelConverter)) pipeline: org.apache.spark.ml.Pipeline = pipeline_97716da42fed // Train model. This also runs the indexers. 
val model = pipeline.fit(trainingDF) model: org.apache.spark.ml.PipelineModel = pipeline_97716da42fed // 輸出隨機森林模型的全部引數值 model.stages(2).extractParamMap() res10: org.apache.spark.ml.param.ParamMap = { rfc_0d830180d598-cacheNodeIds: false, rfc_0d830180d598-checkpointInterval: 10, rfc_0d830180d598-featureSubsetStrategy: auto, rfc_0d830180d598-featuresCol: indexedFeatures, rfc_0d830180d598-impurity: gini, rfc_0d830180d598-labelCol: indexedLabel, rfc_0d830180d598-maxBins: 32, rfc_0d830180d598-maxDepth: 5, rfc_0d830180d598-maxMemoryInMB: 256, rfc_0d830180d598-minInfoGain: 0.0, rfc_0d830180d598-minInstancesPerNode: 1, rfc_0d830180d598-predictionCol: prediction, rfc_0d830180d598-probabilityCol: probability, rfc_0d830180d598-rawPredictionCol: rawPrediction, rfc_0d830180d598-seed: 207336481, rfc_0d830180d598-subsamplingRate: 1.0 } // 作出預測 val predictions = model.transform(testDF) predictions: org.apache.spark.sql.DataFrame = [label: double, gender: double ... 14 more fields] predictions.select("predictedLabel", "label", "features").show(10,false) +--------------+-----+-------------------------------------+ |predictedLabel|label|features | +--------------+-----+-------------------------------------+ |0.0 |0.0 |[0.0,22.0,0.125,0.0,4.0,12.0,4.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.125,0.0,4.0,14.0,4.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.417,0.0,1.0,17.0,6.0,4.0]| |0.0 |0.0 |[0.0,22.0,0.417,0.0,4.0,14.0,5.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.417,1.0,3.0,14.0,3.0,5.0]| |0.0 |0.0 |[0.0,22.0,0.75,0.0,5.0,18.0,1.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,1.0,14.0,1.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,4.0,16.0,5.0,3.0] | |0.0 |0.0 |[0.0,22.0,1.5,0.0,4.0,17.0,5.0,5.0] | |0.0 |0.0 |[0.0,22.0,1.5,1.0,3.0,12.0,1.0,3.0] | +--------------+-----+-------------------------------------+ only showing top 10 rows // 選擇(預測標籤,實際標籤),並計算測試誤差 val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy") evaluator: 
org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator = mcEval_13a195abc422 val accuracy = evaluator.evaluate(predictions) accuracy: Double = 0.7365591397849462 println("Test Error = " + (1.0 - accuracy)) Test Error = 0.26344086021505375 // 這裡的stages(2)中的“2”對應pipeline中的“rf”,將model強制轉換為RandomForestClassificationModel型別 val rfModel = model.stages(2).asInstanceOf[RandomForestClassificationModel] rfModel: org.apache.spark.ml.classification.RandomForestClassificationModel = RandomForestClassificationModel (uid=rfc_f7bb5e488533) with 10 trees println("Learned classification forest model:\n" + rfModel.toDebugString) Learned classification forest model: RandomForestClassificationModel (uid=rfc_f7bb5e488533) with 10 trees Tree 0 (weight 1.0): If (feature 2 <= 1.5) If (feature 5 <= 12.0) If (feature 6 <= 1.0) Predict: 0.0 Else (feature 6 > 1.0) If (feature 2 <= 0.125) Predict: 0.0 Else (feature 2 > 0.125) Predict: 1.0 Else (feature 5 > 12.0) If (feature 0 in {0.0}) If (feature 5 <= 16.0) Predict: 0.0 Else (feature 5 > 16.0) If (feature 1 <= 22.0) Predict: 0.0 Else (feature 1 > 22.0) Predict: 0.0 Else (feature 0 not in {0.0}) If (feature 2 <= 0.75) If (feature 4 in {0.0,1.0,2.0,4.0}) Predict: 0.0 Else (feature 4 not in {0.0,1.0,2.0,4.0}) Predict: 0.0 Else (feature 2 > 0.75) If (feature 1 <= 22.0) Predict: 0.0 Else (feature 1 > 22.0) Predict: 1.0 Else (feature 2 > 1.5) If (feature 1 <= 42.0) If (feature 1 <= 27.0) If (feature 5 <= 16.0) If (feature 6 <= 5.0) Predict: 0.0 Else (feature 6 > 5.0) Predict: 1.0 Else (feature 5 > 16.0) If (feature 4 in {3.0}) Predict: 0.0 Else (feature 4 not in {3.0}) Predict: 0.0 Else (feature 1 > 27.0) If (feature 4 in {0.0,3.0,4.0}) If (feature 2 <= 4.0) Predict: 1.0 Else (feature 2 > 4.0) Predict: 0.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 1.0 Else (feature 1 > 42.0) If (feature 4 in {2.0,4.0}) Predict: 0.0 Else (feature 4 not in {2.0,4.0}) If (feature 4 in {0.0}) 
Predict: 1.0 Else (feature 4 not in {0.0}) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 0.0 Tree 1 (weight 1.0): If (feature 7 in {0.0,2.0,4.0}) If (feature 7 in {0.0}) If (feature 1 <= 42.0) If (feature 4 in {1.0}) Predict: 0.0 Else (feature 4 not in {1.0}) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 7 not in {0.0}) If (feature 1 <= 17.5) If (feature 4 in {3.0}) Predict: 0.0 Else (feature 4 not in {3.0}) Predict: 1.0 Else (feature 1 > 17.5) If (feature 0 in {0.0}) If (feature 4 in {1.0,3.0,4.0}) Predict: 0.0 Else (feature 4 not in {1.0,3.0,4.0}) Predict: 0.0 Else (feature 0 not in {0.0}) If (feature 6 <= 2.0) Predict: 1.0 Else (feature 6 > 2.0) Predict: 0.0 Else (feature 7 not in {0.0,2.0,4.0}) If (feature 3 in {0.0}) If (feature 5 <= 14.0) If (feature 4 in {1.0,3.0}) Predict: 0.0 Else (feature 4 not in {1.0,3.0}) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 1.0 Else (feature 5 > 14.0) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) If (feature 4 in {0.0,2.0,3.0,4.0}) Predict: 0.0 Else (feature 4 not in {0.0,2.0,3.0,4.0}) Predict: 1.0 Else (feature 3 not in {0.0}) If (feature 5 <= 12.0) If (feature 0 in {1.0}) Predict: 0.0 Else (feature 0 not in {1.0}) If (feature 6 <= 1.0) Predict: 0.0 Else (feature 6 > 1.0) Predict: 0.0 Else (feature 5 > 12.0) If (feature 4 in {0.0,2.0,3.0,4.0}) If (feature 1 <= 47.0) Predict: 0.0 Else (feature 1 > 47.0) Predict: 1.0 Else (feature 4 not in {0.0,2.0,3.0,4.0}) If (feature 1 <= 22.0) Predict: 1.0 Else (feature 1 > 22.0) Predict: 0.0 Tree 2 (weight 1.0): If (feature 7 in {0.0}) If (feature 4 in {1.0}) Predict: 0.0 Else (feature 4 not in {1.0}) If (feature 6 <= 5.0) If (feature 1 <= 42.0) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 6 > 5.0) Predict: 0.0 Else (feature 7 not in {0.0}) If (feature 5 <= 16.0) If (feature 7 in {1.0}) If (feature 6 <= 4.0) If (feature 2 <= 7.0) Predict: 0.0 Else (feature 2 > 7.0) 
Predict: 1.0 Else (feature 6 > 4.0) Predict: 1.0 Else (feature 7 not in {1.0}) If (feature 3 in {1.0}) If (feature 1 <= 17.5) Predict: 1.0 Else (feature 1 > 17.5) Predict: 0.0 Else (feature 3 not in {1.0}) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0 Else (feature 5 > 16.0) If (feature 3 in {0.0}) If (feature 4 in {4.0}) Predict: 0.0 Else (feature 4 not in {4.0}) If (feature 5 <= 18.0) Predict: 0.0 Else (feature 5 > 18.0) Predict: 0.0 Else (feature 3 not in {0.0}) If (feature 4 in {0.0,3.0,4.0}) If (feature 7 in {2.0}) Predict: 0.0 Else (feature 7 not in {2.0}) Predict: 0.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 1.0 Tree 3 (weight 1.0): If (feature 3 in {0.0}) If (feature 7 in {3.0}) Predict: 0.0 Else (feature 7 not in {3.0}) If (feature 2 <= 10.0) If (feature 4 in {2.0,3.0,4.0}) If (feature 4 in {4.0}) Predict: 0.0 Else (feature 4 not in {4.0}) Predict: 0.0 Else (feature 4 not in {2.0,3.0,4.0}) If (feature 7 in {0.0,2.0,4.0}) Predict: 0.0 Else (feature 7 not in {0.0,2.0,4.0}) Predict: 1.0 Else (feature 2 > 10.0) Predict: 1.0 Else (feature 3 not in {0.0}) If (feature 6 <= 2.0) If (feature 5 <= 16.0) If (feature 7 in {0.0,1.0,2.0,4.0}) If (feature 4 in {0.0,1.0,3.0,4.0}) Predict: 0.0 Else (feature 4 not in {0.0,1.0,3.0,4.0}) Predict: 1.0 Else (feature 7 not in {0.0,1.0,2.0,4.0}) If (feature 1 <= 22.0) Predict: 0.0 Else (feature 1 > 22.0) Predict: 0.0 Else (feature 5 > 16.0) If (feature 7 in {0.0,1.0,3.0}) Predict: 0.0 Else (feature 7 not in {0.0,1.0,3.0}) Predict: 1.0 Else (feature 6 > 2.0) If (feature 4 in {0.0,3.0,4.0}) If (feature 7 in {0.0,2.0,3.0,4.0}) If (feature 4 in {3.0,4.0}) Predict: 0.0 Else (feature 4 not in {3.0,4.0}) Predict: 0.0 Else (feature 7 not in {0.0,2.0,3.0,4.0}) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 1.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 1 <= 22.0) If (feature 5 <= 14.0) Predict: 1.0 Else (feature 5 > 
14.0) Predict: 1.0 Else (feature 1 > 22.0) If (feature 6 <= 6.0) Predict: 0.0 Else (feature 6 > 6.0) Predict: 1.0 Tree 4 (weight 1.0): If (feature 7 in {0.0,2.0,4.0}) If (feature 7 in {0.0}) If (feature 6 <= 5.0) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) If (feature 4 in {2.0,4.0}) Predict: 1.0 Else (feature 4 not in {2.0,4.0}) Predict: 1.0 Else (feature 6 > 5.0) Predict: 0.0 Else (feature 7 not in {0.0}) If (feature 2 <= 1.5) If (feature 5 <= 12.0) If (feature 2 <= 0.125) Predict: 0.0 Else (feature 2 > 0.125) Predict: 0.0 Else (feature 5 > 12.0) If (feature 1 <= 17.5) Predict: 1.0 Else (feature 1 > 17.5) Predict: 0.0 Else (feature 2 > 1.5) If (feature 2 <= 7.0) If (feature 4 in {1.0,3.0,4.0}) Predict: 0.0 Else (feature 4 not in {1.0,3.0,4.0}) Predict: 0.0 Else (feature 2 > 7.0) If (feature 5 <= 16.0) Predict: 0.0 Else (feature 5 > 16.0) Predict: 0.0 Else (feature 7 not in {0.0,2.0,4.0}) If (feature 5 <= 12.0) Predict: 0.0 Else (feature 5 > 12.0) If (feature 4 in {0.0,3.0,4.0}) If (feature 1 <= 47.0) If (feature 1 <= 22.0) Predict: 0.0 Else (feature 1 > 22.0) Predict: 0.0 Else (feature 1 > 47.0) Predict: 1.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 1 <= 27.0) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 0.0 Else (feature 1 > 27.0) If (feature 5 <= 14.0) Predict: 1.0 Else (feature 5 > 14.0) Predict: 1.0 Tree 5 (weight 1.0): If (feature 7 in {0.0}) If (feature 1 <= 42.0) If (feature 6 <= 4.0) Predict: 1.0 Else (feature 6 > 4.0) If (feature 4 in {1.0}) Predict: 0.0 Else (feature 4 not in {1.0}) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 7 not in {0.0}) If (feature 2 <= 1.5) If (feature 4 in {0.0,2.0,3.0}) If (feature 1 <= 22.0) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0 Else (feature 1 > 22.0) Predict: 0.0 Else (feature 4 not in {0.0,2.0,3.0}) If (feature 1 <= 17.5) If (feature 6 <= 4.0) Predict: 1.0 Else (feature 6 > 4.0) Predict: 0.0 Else 
(feature 1 > 17.5) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0 Else (feature 2 > 1.5) If (feature 6 <= 5.0) If (feature 5 <= 17.0) If (feature 7 in {2.0,4.0}) Predict: 0.0 Else (feature 7 not in {2.0,4.0}) Predict: 0.0 Else (feature 5 > 17.0) If (feature 6 <= 1.0) Predict: 0.0 Else (feature 6 > 1.0) Predict: 0.0 Else (feature 6 > 5.0) If (feature 4 in {0.0,3.0,4.0}) If (feature 7 in {3.0,4.0}) Predict: 0.0 Else (feature 7 not in {3.0,4.0}) Predict: 0.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 6 <= 6.0) Predict: 0.0 Else (feature 6 > 6.0) Predict: 0.0 Tree 6 (weight 1.0): If (feature 4 in {0.0,3.0,4.0}) If (feature 5 <= 12.0) If (feature 7 in {1.0,2.0,3.0,4.0}) Predict: 0.0 Else (feature 7 not in {1.0,2.0,3.0,4.0}) If (feature 6 <= 3.0) Predict: 0.0 Else (feature 6 > 3.0) Predict: 1.0 Else (feature 5 > 12.0) If (feature 7 in {0.0,1.0,2.0}) If (feature 6 <= 1.0) If (feature 7 in {0.0,2.0}) Predict: 0.0 Else (feature 7 not in {0.0,2.0}) Predict: 0.0 Else (feature 6 > 1.0) If (feature 1 <= 37.0) Predict: 1.0 Else (feature 1 > 37.0) Predict: 0.0 Else (feature 7 not in {0.0,1.0,2.0}) If (feature 1 <= 17.5) If (feature 4 in {3.0}) Predict: 0.0 Else (feature 4 not in {3.0}) Predict: 1.0 Else (feature 1 > 17.5) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 0.0 Else (feature 4 not in {0.0,3.0,4.0}) If (feature 7 in {0.0,4.0}) If (feature 5 <= 12.0) If (feature 2 <= 0.125) Predict: 0.0 Else (feature 2 > 0.125) If (feature 1 <= 17.5) Predict: 1.0 Else (feature 1 > 17.5) Predict: 0.0 Else (feature 5 > 12.0) If (feature 7 in {0.0}) If (feature 1 <= 42.0) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 7 not in {0.0}) If (feature 2 <= 1.5) Predict: 0.0 Else (feature 2 > 1.5) Predict: 0.0 Else (feature 7 not in {0.0,4.0}) If (feature 6 <= 4.0) If (feature 7 in {3.0}) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0 Else (feature 7 not in {3.0}) If (feature 5 <= 16.0) 
Predict: 0.0 Else (feature 5 > 16.0) Predict: 1.0 Else (feature 6 > 4.0) If (feature 6 <= 6.0) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 1.0 Else (feature 6 > 6.0) If (feature 5 <= 18.0) Predict: 1.0 Else (feature 5 > 18.0) Predict: 0.0 Tree 7 (weight 1.0): If (feature 7 in {0.0,2.0,4.0}) If (feature 2 <= 1.5) If (feature 4 in {1.0,2.0,3.0}) If (feature 1 <= 17.5) Predict: 1.0 Else (feature 1 > 17.5) Predict: 0.0 Else (feature 4 not in {1.0,2.0,3.0}) If (feature 5 <= 14.0) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 1.0 Else (feature 5 > 14.0) Predict: 0.0 Else (feature 2 > 1.5) If (feature 7 in {0.0,2.0}) If (feature 4 in {1.0,3.0,4.0}) If (feature 5 <= 16.0) Predict: 0.0 Else (feature 5 > 16.0) Predict: 0.0 Else (feature 4 not in {1.0,3.0,4.0}) If (feature 6 <= 5.0) Predict: 1.0 Else (feature 6 > 5.0) Predict: 0.0 Else (feature 7 not in {0.0,2.0}) If (feature 4 in {0.0,1.0,3.0}) If (feature 1 <= 42.0) Predict: 0.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 4 not in {0.0,1.0,3.0}) If (feature 5 <= 16.0) Predict: 0.0 Else (feature 5 > 16.0) Predict: 0.0 Else (feature 7 not in {0.0,2.0,4.0}) If (feature 2 <= 0.75) Predict: 0.0 Else (feature 2 > 0.75) If (feature 4 in {4.0}) If (feature 6 <= 5.0) If (feature 1 <= 37.0) Predict: 1.0 Else (feature 1 > 37.0) Predict: 0.0 Else (feature 6 > 5.0) Predict: 0.0 Else (feature 4 not in {4.0}) If (feature 5 <= 12.0) If (feature 1 <= 27.0) Predict: 0.0 Else (feature 1 > 27.0) Predict: 0.0 Else (feature 5 > 12.0) If (feature 7 in {1.0}) Predict: 1.0 Else (feature 7 not in {1.0}) Predict: 0.0 Tree 8 (weight 1.0): If (feature 5 <= 16.0) If (feature 4 in {0.0,1.0}) If (feature 0 in {0.0}) If (feature 2 <= 0.75) If (feature 1 <= 17.5) Predict: 1.0 Else (feature 1 > 17.5) Predict: 0.0 Else (feature 2 > 0.75) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 0.0 Else (feature 0 not in {0.0}) If (feature 5 <= 12.0) Predict: 1.0 Else (feature 5 > 
12.0) If (feature 7 in {2.0,4.0}) Predict: 0.0 Else (feature 7 not in {2.0,4.0}) Predict: 0.0 Else (feature 4 not in {0.0,1.0}) If (feature 7 in {0.0,2.0,3.0,4.0}) If (feature 1 <= 22.0) If (feature 6 <= 3.0) Predict: 0.0 Else (feature 6 > 3.0) Predict: 0.0 Else (feature 1 > 22.0) If (feature 6 <= 6.0) Predict: 0.0 Else (feature 6 > 6.0) Predict: 1.0 Else (feature 7 not in {0.0,2.0,3.0,4.0}) If (feature 1 <= 42.0) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 5 > 16.0) If (feature 5 <= 18.0) If (feature 4 in {3.0}) If (feature 7 in {1.0,2.0,3.0}) Predict: 0.0 Else (feature 7 not in {1.0,2.0,3.0}) If (feature 6 <= 5.0) Predict: 0.0 Else (feature 6 > 5.0) Predict: 0.0 Else (feature 4 not in {3.0}) If (feature 2 <= 0.75) Predict: 0.0 Else (feature 2 > 0.75) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 1.0 Else (feature 5 > 18.0) If (feature 1 <= 27.0) If (feature 7 in {3.0}) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 1.0 Else (feature 7 not in {3.0}) If (feature 2 <= 4.0) Predict: 0.0 Else (feature 2 > 4.0) Predict: 1.0 Else (feature 1 > 27.0) If (feature 6 <= 5.0) If (feature 6 <= 4.0) Predict: 0.0 Else (feature 6 > 4.0) Predict: 0.0 Else (feature 6 > 5.0) If (feature 4 in {3.0,4.0}) Predict: 0.0 Else (feature 4 not in {3.0,4.0}) Predict: 0.0 Tree 9 (weight 1.0): If (feature 5 <= 16.0) If (feature 6 <= 2.0) If (feature 1 <= 42.0) If (feature 6 <= 1.0) If (feature 5 <= 9.0) Predict: 1.0 Else (feature 5 > 9.0) Predict: 0.0 Else (feature 6 > 1.0) If (feature 1 <= 27.0) Predict: 0.0 Else (feature 1 > 27.0) Predict: 1.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 6 > 2.0) If (feature 1 <= 27.0) If (feature 5 <= 14.0) If (feature 6 <= 3.0) Predict: 0.0 Else (feature 6 > 3.0) Predict: 0.0 Else (feature 5 > 14.0) Predict: 0.0 Else (feature 1 > 27.0) If (feature 4 in {1.0,2.0,4.0}) If (feature 5 <= 9.0) Predict: 0.0 Else (feature 5 
> 9.0) Predict: 0.0 Else (feature 4 not in {1.0,2.0,4.0}) If (feature 7 in {2.0,3.0,4.0}) Predict: 0.0 Else (feature 7 not in {2.0,3.0,4.0}) Predict: 1.0 Else (feature 5 > 16.0) If (feature 6 <= 4.0) If (feature 4 in {3.0}) Predict: 0.0 Else (feature 4 not in {3.0}) If (feature 1 <= 42.0) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 0.0 Else (feature 1 > 42.0) Predict: 1.0 Else (feature 6 > 4.0) If (feature 4 in {3.0,4.0}) If (feature 1 <= 37.0) If (feature 3 in {0.0}) Predict: 0.0 Else (feature 3 not in {0.0}) Predict: 0.0 Else (feature 1 > 37.0) If (feature 1 <= 42.0) Predict: 0.0 Else (feature 1 > 42.0) Predict: 0.0 Else (feature 4 not in {3.0,4.0}) If (feature 4 in {0.0,2.0}) If (feature 7 in {0.0,1.0,2.0}) Predict: 1.0 Else (feature 7 not in {0.0,1.0,2.0}) Predict: 1.0 Else (feature 4 not in {0.0,2.0}) If (feature 0 in {0.0}) Predict: 0.0 Else (feature 0 not in {0.0}) Predict: 0.0
隨機森林模型調優
// ---------------------------------------------------------------------------
// Random forest tuning: grid search over RandomForestClassifier parameters,
// scored by accuracy with 5-fold cross-validation.
// Uses `featuresArray` and `dataLabelDF`, which are defined earlier in this file.
// ---------------------------------------------------------------------------

// Assemble the individual feature columns into a single feature vector column.
val assembler = new VectorAssembler()
  .setInputCols(featuresArray)
  .setOutputCol("features")
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF.show(10, truncate = false)

// Split the data into training and test sets (30% held out for testing).
val Array(trainingDF, testDF) = vecDF.randomSplit(Array(0.7, 0.3))

// Index the labels, adding metadata to the label column.
// The indexer is fitted on the full `vecDF`, not only on the training split.
val labelIndexer = new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(vecDF)
//labelIndexer.transform(vecDF).show(10, truncate = false)

// Automatically identify categorical features and index them.
// Features with more than 5 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(5)
  .fit(vecDF)
//featureIndexer.transform(vecDF).show(10, truncate = false)

// Random forest estimator; its parameters are supplied by the grid below.
val rf = new RandomForestClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")

// Convert indexed predictions back to the original label values.
val labelConverter = new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels)

// Chain indexers and forest in a Pipeline.
val pipeline = new Pipeline()
  .setStages(Array(labelIndexer, featureIndexer, rf, labelConverter))

// Parameter grid.
//   impurity              impurity measure
//   maxBins               maximum number of bins used to discretize continuous features
//   maxDepth              maximum depth of a tree
//   minInfoGain           minimum information gain required to split a node, in [0, 1]
//   minInstancesPerNode   minimum number of samples each node must contain (>= 1)
//   numTrees              number of trees
//   featureSubsetStrategy number of features considered for a split at each tree node;
//                         there are many accepted values, see the official documentation
//   subsamplingRate       fraction of the training data given to each tree, in (0, 1]
//   seed                  random seed
//
// Parameters deliberately NOT searched (left at their defaults):
//   maxMemoryInMB         if too small, each iteration splits only 1 node, and its
//                         aggregates may exceed this size
//   checkpointInterval    checkpoint interval (>= 1), or -1 to disable checkpointing;
//                         e.g. 10 means the cache is checkpointed every 10 iterations
//   cacheNodeIds          if false, the algorithm passes trees to the executors to match
//                         instances with nodes; if true, it caches the node id of every
//                         instance, which can speed up training of deeper trees; how
//                         often the cache is checkpointed is set via checkpointInterval
//                         (default = false)
// These three control memory use, checkpointing and caching during training rather than
// being model-quality hyperparameters. Crossing them into the grid multiplied it by 8
// (9216 parameter maps, i.e. 46080 pipeline fits with 5 folds); without them the grid
// has 1152 parameter maps (5760 fits).
val paramGrid = new ParamGridBuilder()
  .addGrid(rf.impurity, Array("entropy", "gini"))
  .addGrid(rf.maxBins, Array(32, 64))
  .addGrid(rf.maxDepth, Array(5, 7, 10))
  .addGrid(rf.minInfoGain, Array(0.0, 0.5, 1.0))
  .addGrid(rf.minInstancesPerNode, Array(10, 20))
  .addGrid(rf.numTrees, Array(20, 50))
  .addGrid(rf.featureSubsetStrategy, Array("auto", "sqrt"))
  .addGrid(rf.subsamplingRate, Array(0.8, 1.0))
  .addGrid(rf.seed, Array(123456L, 111L))
  .build()

// Compare (prediction, true label) and compute accuracy. Both `indexedLabel` and
// `prediction` are index-encoded, so they can be compared directly.
val classEvaluator = new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy")

// Configure 5-fold cross-validation over the whole pipeline.
val cv = new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(classEvaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(5)

// Run cross-validation and pick the best parameter combination.
val cvModel = cv.fit(trainingDF)

// Inspect all parameters of the fitted cross-validator model.
cvModel.extractParamMap()

// cvModel.avgMetrics.length == cvModel.getEstimatorParamMaps.length;
// the elements of avgMetrics and getEstimatorParamMaps correspond one-to-one.
cvModel.avgMetrics.length
cvModel.avgMetrics // average metric for each parameter combination
cvModel.getEstimatorParamMaps.length
cvModel.getEstimatorParamMaps // the set of parameter combinations
cvModel.getEvaluator.extractParamMap() // parameters of the evaluator
cvModel.getEvaluator.isLargerBetter // whether a larger or a smaller metric value is better
cvModel.getNumFolds // number of cross-validation folds

//################################
// Test the model on the held-out split.
// Each column is selected exactly once; listing "features" twice would produce two
// columns with the same name in the result.
val predictDF: DataFrame = cvModel.transform(testDF).selectExpr(
  //"race","poverty","smoke","alcohol","agemth","ybirth","yschool","pc3mth",
  "features", "predictedLabel", "label")
predictDF.show(20, truncate = false)