Spark2 ML包之決策樹分類Decision tree classifier詳細解說

智慧先行者發表於2016-11-29

所用資料來源,請參考本人部落格http://www.cnblogs.com/wwxbi/p/6063613.html

1.匯入包

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.sql.functions._

import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.classification.DecisionTreeClassificationModel
import org.apache.spark.ml.classification.DecisionTreeClassifier
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.IndexToString
import org.apache.spark.ml.feature.VectorIndexer
import org.apache.spark.ml.feature.VectorSlicer

 

2.載入資料來源

// Create the SparkSession used by the rest of the script (or reuse an existing one).
// NOTE(review): "spark.some.config.option" looks like a placeholder setting — confirm it is needed.
val spark = (
  SparkSession
    .builder()
    .appName("Spark decision tree classifier")
    .config("spark.some.config.option", "some-value")
    .getOrCreate()
)

// Brings toDF and the other implicit conversions (e.g. local collections / RDDs to DataFrames) into scope.
import spark.implicits._

// Sample rows only; for the full data set see the author's blog:
// http://www.cnblogs.com/wwxbi/p/6063613.html
// Tuple fields follow the column order given to toDF below.
val dataList: List[(Double, String, Double, Double, String, Double, Double, Double, Double)] = List(
  (0.0, "male", 37.0, 10.0, "no", 3.0, 18.0, 7.0, 4.0),
  (0.0, "female", 27.0, 4.0, "no", 4.0, 14.0, 6.0, 4.0),
  (0.0, "female", 32.0, 15.0, "yes", 1.0, 12.0, 1.0, 4.0),
  (0.0, "male", 57.0, 15.0, "yes", 5.0, 18.0, 6.0, 5.0),
  (0.0, "male", 22.0, 0.75, "no", 2.0, 17.0, 6.0, 3.0),
  (0.0, "female", 32.0, 1.5, "no", 2.0, 17.0, 5.0, 5.0))

val data = dataList.toDF(
  "affairs", "gender", "age", "yearsmarried", "children",
  "religiousness", "education", "occupation", "rating")

// Register the DataFrame so the SQL below can query it as "data".
data.createOrReplaceTempView("data")

// Convert the label and the string columns to 0/1 doubles with SQL CASE expressions:
//   label:    affairs = 0 -> 0, anything else -> 1
//   gender:   'female'    -> 0, anything else -> 1
//   children: 'no'        -> 0, anything else -> 1
val labelWhere = "case when affairs=0 then 0 else cast(1 as double) end as label"
val genderWhere = "case when gender='female' then 0 else cast(1 as double) end as gender"
val childrenWhere = "case when children='no' then 0 else cast(1 as double) end as children"

// The remaining numeric columns are selected unchanged.
val dataLabelDF = spark.sql(
  s"select $labelWhere, $genderWhere," +
    s"age,yearsmarried,$childrenWhere," +
    "religiousness,education,occupation,rating from data")

 

3.建立決策樹模型

// Columns packed into the feature vector. Their position here is the feature index used
// in the tree's debug output (feature 0 = gender, ..., feature 7 = rating).
val featuresArray = Array(
  "gender", "age", "yearsmarried", "children",
  "religiousness", "education", "occupation", "rating")

// Combine the feature columns into a single vector column named "features".
val assembler = (new VectorAssembler()
  .setInputCols(featuresArray)
  .setOutputCol("features"))
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF.show(numRows = 10, truncate = false)

// Index the label column into "indexedLabel" (adding label metadata to that column);
// the fitted model's labels are reused by IndexToString further below.
val labelIndexer = (new StringIndexer()
  .setInputCol("label")
  .setOutputCol("indexedLabel")
  .fit(vecDF))
labelIndexer.transform(vecDF).show(numRows = 10, truncate = false)

// Let VectorIndexer detect which features are categorical and index them:
// a feature with more than 5 distinct values is treated as continuous.
val featureIndexer = (new VectorIndexer()
  .setInputCol("features")
  .setOutputCol("indexedFeatures")
  .setMaxCategories(5)
  .fit(vecDF))
featureIndexer.transform(vecDF).show(numRows = 10, truncate = false)

// Split the data into training (70%) and test (30%) sets. The seed makes the split, and
// therefore the reported test error, reproducible. Without it, every run draws a different
// split even though the tree itself is seeded below.
val Array(trainingData, testData) = vecDF.randomSplit(Array(0.7, 0.3), seed = 123456L)

// Configure the decision tree. The chain is wrapped in parentheses so it can also be pasted
// into spark-shell. There, a line ending right after `new DecisionTreeClassifier()` would be
// evaluated on its own, and the following `.setXxx` lines would then fail to parse.
val dt = (new DecisionTreeClassifier()
  .setLabelCol("indexedLabel")
  .setFeaturesCol("indexedFeatures")
  .setImpurity("entropy")      // impurity measure used to choose splits
  .setMaxBins(100)             // max number of bins when discretizing continuous features
  .setMaxDepth(5)              // maximum depth of the tree
  .setMinInfoGain(0.01)        // minimum information gain for a split to be kept; must be >= 0.0
  .setMinInstancesPerNode(10)  // each child of a split must keep at least this many instances
  .setSeed(123456))

// Map numeric predictions back to the original label values learned by labelIndexer.
val labelConverter = (new IndexToString()
  .setInputCol("prediction")
  .setOutputCol("predictedLabel")
  .setLabels(labelIndexer.labels))

// Stages run in order: label indexer, feature indexer, decision tree, label converter.
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Fit on the training split. Both indexers were already fitted on vecDF above,
// so inside the pipeline they only transform; this step trains the tree.
val model = pipeline.fit(trainingData)

// Score the held-out test split.
val predictions = model.transform(testData)

// Show a few example rows.
predictions.select("predictedLabel", "label", "features").show(numRows = 10, truncate = false)

// Accuracy on the indexed labels; the test error is reported as 1 - accuracy.
val evaluator = (new MulticlassClassificationEvaluator()
  .setLabelCol("indexedLabel")
  .setPredictionCol("prediction")
  .setMetricName("accuracy"))
val accuracy = evaluator.evaluate(predictions)
println(s"Test Error = ${1.0 - accuracy}")

// Pull the fitted tree out of the pipeline model by type rather than by position.
// The original used stages(2), which relies on "dt" being the third stage. This lookup
// keeps working if stages are added or reordered, and it fails with a clear message
// instead of a ClassCastException.
val treeModel: DecisionTreeClassificationModel = model.stages.collectFirst {
  case tree: DecisionTreeClassificationModel => tree
}.getOrElse(throw new IllegalStateException("pipeline model has no DecisionTreeClassificationModel stage"))

// Column names the model reads and writes, plus the per-feature importance vector.
treeModel.getLabelCol
treeModel.getFeaturesCol
treeModel.featureImportances
treeModel.getPredictionCol
treeModel.getProbabilityCol

// Size and shape of the learned tree.
treeModel.numClasses
treeModel.numFeatures
treeModel.depth
treeModel.numNodes

// Training parameters carried by the fitted model.
treeModel.getImpurity
treeModel.getMaxBins
treeModel.getMaxDepth
treeModel.getMaxMemoryInMB
treeModel.getMinInfoGain
treeModel.getMinInstancesPerNode

// Print the whole tree as nested If/Else rules.
println("Learned classification tree model:\n" + treeModel.toDebugString)

 

4.程式碼執行結果

// Build the DataFrame from the sample tuples, naming the columns.
val data = dataList.toDF("affairs", "gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")

// Show the first 10 rows without truncating cell contents.
data.show(10, truncate = false)
+-------+------+----+------------+--------+-------------+---------+----------+------+
|affairs|gender|age |yearsmarried|children|religiousness|education|occupation|rating|
+-------+------+----+------------+--------+-------------+---------+----------+------+
|0.0    |male  |37.0|10.0        |no      |3.0          |18.0     |7.0       |4.0   |
|0.0    |female|27.0|4.0         |no      |4.0          |14.0     |6.0       |4.0   |
|0.0    |female|32.0|15.0        |yes     |1.0          |12.0     |1.0       |4.0   |
|0.0    |male  |57.0|15.0        |yes     |5.0          |18.0     |6.0       |5.0   |
|0.0    |male  |22.0|0.75        |no      |2.0          |17.0     |6.0       |3.0   |
|0.0    |female|32.0|1.5         |no      |2.0          |17.0     |5.0       |5.0   |
|0.0    |female|22.0|0.75        |no      |2.0          |12.0     |1.0       |3.0   |
|0.0    |male  |57.0|15.0        |yes     |2.0          |14.0     |4.0       |4.0   |
|0.0    |female|32.0|15.0        |yes     |4.0          |16.0     |1.0       |2.0   |
|0.0    |male  |22.0|1.5         |no      |4.0          |14.0     |4.0       |5.0   |
+-------+------+----+------------+--------+-------------+---------+----------+------+
only showing top 10 rows


// Register the DataFrame as the temp view "data" for the SQL below.
data.createOrReplaceTempView("data")

// Convert the label and string columns to numeric 0/1 values with CASE expressions.
val labelWhere = "case when affairs=0 then 0 else cast(1 as double) end as label"
val genderWhere = "case when gender='female' then 0 else cast(1 as double) end as gender"
val childrenWhere = "case when children='no' then 0 else cast(1 as double) end as children"

val dataLabelDF = spark.sql(s"select $labelWhere, $genderWhere,age,yearsmarried,$childrenWhere,religiousness,education,occupation,rating from data")

dataLabelDF.show(10, truncate = false)
+-----+------+----+------------+--------+-------------+---------+----------+------+
|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|
+-----+------+----+------------+--------+-------------+---------+----------+------+
|0.0  |1.0   |37.0|10.0        |0.0     |3.0          |18.0     |7.0       |4.0   |
|0.0  |0.0   |27.0|4.0         |0.0     |4.0          |14.0     |6.0       |4.0   |
|0.0  |0.0   |32.0|15.0        |1.0     |1.0          |12.0     |1.0       |4.0   |
|0.0  |1.0   |57.0|15.0        |1.0     |5.0          |18.0     |6.0       |5.0   |
|0.0  |1.0   |22.0|0.75        |0.0     |2.0          |17.0     |6.0       |3.0   |
|0.0  |0.0   |32.0|1.5         |0.0     |2.0          |17.0     |5.0       |5.0   |
|0.0  |0.0   |22.0|0.75        |0.0     |2.0          |12.0     |1.0       |3.0   |
|0.0  |1.0   |57.0|15.0        |1.0     |2.0          |14.0     |4.0       |4.0   |
|0.0  |0.0   |32.0|15.0        |1.0     |4.0          |16.0     |1.0       |2.0   |
|0.0  |1.0   |22.0|1.5         |0.0     |4.0          |14.0     |4.0       |5.0   |
+-----+------+----+------------+--------+-------------+---------+----------+------+
only showing top 10 rows


// Feature columns; the position in this array is the feature index in the tree output.
val featuresArray = Array("gender", "age", "yearsmarried", "children", "religiousness", "education", "occupation", "rating")

// Assemble the feature columns into a single vector column named "features".
val assembler = new VectorAssembler().setInputCols(featuresArray).setOutputCol("features")
val vecDF: DataFrame = assembler.transform(dataLabelDF)
vecDF.show(10, truncate = false)
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+
|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features                            |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+
|0.0  |1.0   |37.0|10.0        |0.0     |3.0          |18.0     |7.0       |4.0   |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|
|0.0  |0.0   |27.0|4.0         |0.0     |4.0          |14.0     |6.0       |4.0   |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |
|0.0  |0.0   |32.0|15.0        |1.0     |1.0          |12.0     |1.0       |4.0   |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|
|0.0  |1.0   |57.0|15.0        |1.0     |5.0          |18.0     |6.0       |5.0   |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|
|0.0  |1.0   |22.0|0.75        |0.0     |2.0          |17.0     |6.0       |3.0   |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|
|0.0  |0.0   |32.0|1.5         |0.0     |2.0          |17.0     |5.0       |5.0   |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |
|0.0  |0.0   |22.0|0.75        |0.0     |2.0          |12.0     |1.0       |3.0   |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|
|0.0  |1.0   |57.0|15.0        |1.0     |2.0          |14.0     |4.0       |4.0   |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|
|0.0  |0.0   |32.0|15.0        |1.0     |4.0          |16.0     |1.0       |2.0   |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|
|0.0  |1.0   |22.0|1.5         |0.0     |4.0          |14.0     |4.0       |5.0   |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+
only showing top 10 rows


// Index the label column into "indexedLabel", adding label metadata to that column.
val labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(vecDF)
labelIndexer.transform(vecDF).show(10, truncate = false)
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+
|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features                            |indexedLabel|
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+
|0.0  |1.0   |37.0|10.0        |0.0     |3.0          |18.0     |7.0       |4.0   |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|0.0         |
|0.0  |0.0   |27.0|4.0         |0.0     |4.0          |14.0     |6.0       |4.0   |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |0.0         |
|0.0  |0.0   |32.0|15.0        |1.0     |1.0          |12.0     |1.0       |4.0   |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|0.0         |
|0.0  |1.0   |57.0|15.0        |1.0     |5.0          |18.0     |6.0       |5.0   |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|0.0         |
|0.0  |1.0   |22.0|0.75        |0.0     |2.0          |17.0     |6.0       |3.0   |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|0.0         |
|0.0  |0.0   |32.0|1.5         |0.0     |2.0          |17.0     |5.0       |5.0   |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |0.0         |
|0.0  |0.0   |22.0|0.75        |0.0     |2.0          |12.0     |1.0       |3.0   |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|0.0         |
|0.0  |1.0   |57.0|15.0        |1.0     |2.0          |14.0     |4.0       |4.0   |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|0.0         |
|0.0  |0.0   |32.0|15.0        |1.0     |4.0          |16.0     |1.0       |2.0   |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|0.0         |
|0.0  |1.0   |22.0|1.5         |0.0     |4.0          |14.0     |4.0       |5.0   |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |0.0         |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------+
only showing top 10 rows


// Automatically detect categorical features and index them.
// Features with more than 5 distinct values are treated as continuous.
val featureIndexer = new VectorIndexer().setInputCol("features").setOutputCol("indexedFeatures").setMaxCategories(5).fit(vecDF)
// Shown once; the transcript previously repeated this line although only one table follows.
featureIndexer.transform(vecDF).show(10, truncate = false)
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------------------------------+
|label|gender|age |yearsmarried|children|religiousness|education|occupation|rating|features                            |indexedFeatures                     |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------------------------------+
|0.0  |1.0   |37.0|10.0        |0.0     |3.0          |18.0     |7.0       |4.0   |[1.0,37.0,10.0,0.0,3.0,18.0,7.0,4.0]|[1.0,37.0,10.0,0.0,2.0,18.0,7.0,3.0]|
|0.0  |0.0   |27.0|4.0         |0.0     |4.0          |14.0     |6.0       |4.0   |[0.0,27.0,4.0,0.0,4.0,14.0,6.0,4.0] |[0.0,27.0,4.0,0.0,3.0,14.0,6.0,3.0] |
|0.0  |0.0   |32.0|15.0        |1.0     |1.0          |12.0     |1.0       |4.0   |[0.0,32.0,15.0,1.0,1.0,12.0,1.0,4.0]|[0.0,32.0,15.0,1.0,0.0,12.0,1.0,3.0]|
|0.0  |1.0   |57.0|15.0        |1.0     |5.0          |18.0     |6.0       |5.0   |[1.0,57.0,15.0,1.0,5.0,18.0,6.0,5.0]|[1.0,57.0,15.0,1.0,4.0,18.0,6.0,4.0]|
|0.0  |1.0   |22.0|0.75        |0.0     |2.0          |17.0     |6.0       |3.0   |[1.0,22.0,0.75,0.0,2.0,17.0,6.0,3.0]|[1.0,22.0,0.75,0.0,1.0,17.0,6.0,2.0]|
|0.0  |0.0   |32.0|1.5         |0.0     |2.0          |17.0     |5.0       |5.0   |[0.0,32.0,1.5,0.0,2.0,17.0,5.0,5.0] |[0.0,32.0,1.5,0.0,1.0,17.0,5.0,4.0] |
|0.0  |0.0   |22.0|0.75        |0.0     |2.0          |12.0     |1.0       |3.0   |[0.0,22.0,0.75,0.0,2.0,12.0,1.0,3.0]|[0.0,22.0,0.75,0.0,1.0,12.0,1.0,2.0]|
|0.0  |1.0   |57.0|15.0        |1.0     |2.0          |14.0     |4.0       |4.0   |[1.0,57.0,15.0,1.0,2.0,14.0,4.0,4.0]|[1.0,57.0,15.0,1.0,1.0,14.0,4.0,3.0]|
|0.0  |0.0   |32.0|15.0        |1.0     |4.0          |16.0     |1.0       |2.0   |[0.0,32.0,15.0,1.0,4.0,16.0,1.0,2.0]|[0.0,32.0,15.0,1.0,3.0,16.0,1.0,1.0]|
|0.0  |1.0   |22.0|1.5         |0.0     |4.0          |14.0     |4.0       |5.0   |[1.0,22.0,1.5,0.0,4.0,14.0,4.0,5.0] |[1.0,22.0,1.5,0.0,3.0,14.0,4.0,4.0] |
+-----+------+----+------------+--------+-------------+---------+----------+------+------------------------------------+------------------------------------+
only showing top 10 rows


// Split the data into training and test sets (30% held out for testing).
// NOTE(review): no seed is passed, so this split, and the accuracy shown below, varies between runs.
val Array(trainingData, testData) = vecDF.randomSplit(Array(0.7, 0.3))

// Configure the decision tree. The chain is written on one line here, presumably so it can be
// pasted into spark-shell; the commented-out lines below list the same settings one per line.
val dt = new DecisionTreeClassifier().setLabelCol("indexedLabel").setFeaturesCol("indexedFeatures").setImpurity("entropy").setMaxBins(100).setMaxDepth(5).setMinInfoGain(0.01).setMinInstancesPerNode(10).setSeed(123456)
//.setLabelCol("indexedLabel")
//.setFeaturesCol("indexedFeatures")
//.setImpurity("entropy") // impurity measure used to choose splits
//.setMaxBins(100) // max number of bins when discretizing continuous features
//.setMaxDepth(5) // maximum depth of the tree
//.setMinInfoGain(0.01) // minimum information gain for a split to be kept; must be >= 0.0
//.setMinInstancesPerNode(10) // each child of a split must keep at least this many instances
//.setSeed(123456)

// Convert the indexed predictions back to the original label values.
val labelConverter = new IndexToString().setInputCol("prediction").setOutputCol("predictedLabel").setLabels(labelIndexer.labels)

// Chain indexers and tree in a Pipeline.
val pipeline = new Pipeline().setStages(Array(labelIndexer, featureIndexer, dt, labelConverter))

// Train model. This also runs the indexers.
val model = pipeline.fit(trainingData)

// Make predictions on the test split.
val predictions = model.transform(testData)

// Show a few example rows.
predictions.select("predictedLabel", "label", "features").show(10, truncate = false)
+--------------+-----+-------------------------------------+
|predictedLabel|label|features                             |
+--------------+-----+-------------------------------------+
|0.0           |0.0  |[0.0,22.0,0.125,0.0,2.0,14.0,4.0,5.0]|
|0.0           |0.0  |[0.0,22.0,0.417,0.0,1.0,17.0,6.0,4.0]|
|0.0           |0.0  |[0.0,22.0,0.75,0.0,2.0,18.0,6.0,5.0] |
|0.0           |0.0  |[0.0,22.0,0.75,0.0,3.0,16.0,1.0,5.0] |
|0.0           |0.0  |[0.0,22.0,0.75,0.0,4.0,16.0,1.0,5.0] |
|0.0           |0.0  |[0.0,22.0,1.5,0.0,1.0,14.0,1.0,5.0]  |
|0.0           |0.0  |[0.0,22.0,1.5,0.0,2.0,14.0,1.0,5.0]  |
|0.0           |0.0  |[0.0,22.0,1.5,0.0,2.0,16.0,5.0,5.0]  |
|0.0           |0.0  |[0.0,22.0,1.5,0.0,2.0,16.0,5.0,5.0]  |
|0.0           |0.0  |[0.0,22.0,1.5,0.0,2.0,17.0,5.0,4.0]  |
+--------------+-----+-------------------------------------+


// Compare (prediction, indexed label) pairs: compute accuracy, then report test error = 1 - accuracy.
val evaluator = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setPredictionCol("prediction").setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
accuracy: Double = 0.6972972972972973

println("Test Error = " + (1.0 - accuracy))
Test Error = 0.3027027027027027


// stages(2) is the "dt" stage of the pipeline (index 2 in Array(labelIndexer, featureIndexer, dt, labelConverter));
// cast the fitted stage to DecisionTreeClassificationModel.
val treeModel = model.stages(2).asInstanceOf[DecisionTreeClassificationModel]
DecisionTreeClassificationModel (uid=dtc_b950f91d35f8) of depth 5 with 43 nodes

// Input/output column names of the fitted model.
treeModel.getLabelCol
String = indexedLabel

treeModel.getFeaturesCol
String = indexedFeatures

// Per-feature importance vector (sparse; indices refer to positions in featuresArray).
treeModel.featureImportances
Vector = (8,[0,1,2,4,5,6,7],[0.012972759843658999,0.1075317063921102,0.11654682273543511,0.17869552275855793,0.07532637852021348,0.27109893303920024,0.237827876710824])
treeModel.getPredictionCol
String = prediction

treeModel.getProbabilityCol
String = probability

// Size and shape of the learned tree.
treeModel.numClasses
Int = 2

treeModel.numFeatures
Int = 8

treeModel.depth
Int = 5

treeModel.numNodes
Int = 43

// Training parameters carried by the fitted model (they match the values set on dt).
treeModel.getImpurity
String = entropy

treeModel.getMaxBins
Int = 100

treeModel.getMaxDepth
Int = 5

treeModel.getMaxMemoryInMB
Int = 256

treeModel.getMinInfoGain
Double = 0.01

treeModel.getMinInstancesPerNode
Int = 10

// Print the learned decision tree as nested If/Else rules.
println("Learned classification tree model:\n" + treeModel.toDebugString)
Learned classification tree model:
DecisionTreeClassificationModel (uid=dtc_b950f91d35f8) of depth 5 with 43 nodes
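// 例如“feature 7 in {0.0,1.0,2.0}”中的“{0.0,1.0,2.0}”是VectorIndexer索引後的類別編號,而不是原始特徵值。feature 7即rating,由上面indexedFeatures的輸出可見rating的2、3、4、5分別被索引為1.0、2.0、3.0、4.0,所以{0.0,1.0,2.0}對應的是rating≤3的三個類別。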
// 具體解釋請參考本人部落格http://www.cnblogs.com/wwxbi/p/6125493.html“VectorIndexer自動識別分類的特徵,並對它們進行索引”
  If (feature 7 in {0.0,1.0,2.0})
   If (feature 7 in {0.0,2.0})
    If (feature 4 in {0.0,4.0})
     Predict: 1.0
    Else (feature 4 not in {0.0,4.0})
     If (feature 1 <= 32.0)
      If (feature 1 <= 27.0)
       Predict: 0.0
      Else (feature 1 > 27.0)
       Predict: 1.0
     Else (feature 1 > 32.0)
      If (feature 5 <= 16.0)
       Predict: 0.0
      Else (feature 5 > 16.0)
       Predict: 0.0
   Else (feature 7 not in {0.0,2.0})
    If (feature 4 in {0.0,1.0,3.0,4.0})
     If (feature 0 in {0.0})
      If (feature 2 <= 7.0)
       Predict: 0.0
      Else (feature 2 > 7.0)
       Predict: 0.0
     Else (feature 0 not in {0.0})
      Predict: 0.0
    Else (feature 4 not in {0.0,1.0,3.0,4.0})
     Predict: 1.0
  Else (feature 7 not in {0.0,1.0,2.0})
   If (feature 2 <= 4.0)
    If (feature 6 <= 3.0)
     If (feature 6 <= 1.0)
      Predict: 0.0
     Else (feature 6 > 1.0)
      Predict: 0.0
    Else (feature 6 > 3.0)
     If (feature 5 <= 16.0)
      If (feature 2 <= 0.75)
       Predict: 0.0
      Else (feature 2 > 0.75)
       Predict: 0.0
     Else (feature 5 > 16.0)
      If (feature 7 in {4.0})
       Predict: 0.0
      Else (feature 7 not in {4.0})
       Predict: 0.0
   Else (feature 2 > 4.0)
    If (feature 6 <= 3.0)
     If (feature 4 in {0.0,1.0,2.0})
      Predict: 0.0
     Else (feature 4 not in {0.0,1.0,2.0})
      If (feature 7 in {4.0})
       Predict: 0.0
      Else (feature 7 not in {4.0})
       Predict: 0.0
    Else (feature 6 > 3.0)
     If (feature 4 in {0.0,2.0,3.0,4.0})
      If (feature 6 <= 4.0)
       Predict: 0.0
      Else (feature 6 > 4.0)
       Predict: 0.0
     Else (feature 4 not in {0.0,2.0,3.0,4.0})
      If (feature 1 <= 37.0)
       Predict: 1.0
      Else (feature 1 > 37.0)
       Predict: 0.0

 

相關文章