Spark2 oneHot編碼--標準化--主成分--聚類

智慧先行者發表於2016-11-03

1.匯入包

import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.Row
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.Column
import org.apache.spark.sql.DataFrameReader
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.functions._
import org.apache.spark.sql.DataFrameStatFunctions
import org.apache.spark.ml.linalg.Vectors
import org.apache.spark.ml.feature.StringIndexer
import org.apache.spark.ml.feature.OneHotEncoder
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.feature.MinMaxScaler
import org.apache.spark.ml.feature.StandardScaler
import org.apache.spark.ml.feature.PCA
import org.apache.spark.ml.clustering.KMeans

 

2.匯入資料

   
// Create (or reuse) the SparkSession that drives every step below.
val spark = SparkSession
  .builder()
  .appName("Spark SQL basic example")
  .config("spark.some.config.option", "some-value")
  .getOrCreate()

// Implicit conversions, e.g. turning local Scala collections/RDDs into DataFrames via toDF.
import spark.implicits._

// Load the CSV from HDFS, treating the first line as the header. No schema inference is
// requested, so every column arrives as a string and is cast explicitly later.
val data: DataFrame = spark.read
  .option("header", true)
  .csv("hdfs://ns1/datafile/wangxiao/Affairs.csv")
data: org.apache.spark.sql.DataFrame = [affairs: string, gender: string ... 7 more fields] 
   
// Mark the raw DataFrame for caching; it is read again by the preview below and by the column casts.
data.cache 
res0: data.type = [affairs: string, gender: string ... 7 more fields] 
  
// Preview the first 10 rows of the raw (all-string) data.
data.limit(10).show() 
+-------+------+---+------------+--------+-------------+---------+----------+------+ 
|affairs|gender|age|yearsmarried|children|religiousness|education|occupation|rating| 
+-------+------+---+------------+--------+-------------+---------+----------+------+ 
|      0|  male| 37|          10|      no|            3|       18|         7|     4| 
|      0|female| 27|           4|      no|            4|       14|         6|     4| 
|      0|female| 32|          15|     yes|            1|       12|         1|     4| 
|      0|  male| 57|          15|     yes|            5|       18|         6|     5| 
|      0|  male| 22|        0.75|      no|            2|       17|         6|     3| 
|      0|female| 32|         1.5|      no|            2|       17|         5|     5| 
|      0|female| 22|        0.75|      no|            2|       12|         1|     3| 
|      0|  male| 57|          15|     yes|            2|       14|         4|     4| 
|      0|female| 32|          15|     yes|            4|       16|         1|     2| 
|      0|  male| 22|         1.5|      no|            4|       14|         4|     5| 
+-------+------+---+------------+--------+-------------+---------+----------+------+ 
   
// Convert column types: every CSV column was loaded as a string, so the numeric
// columns are cast to Double and gender/children stay String. The Double columns
// are selected first, followed by the String columns.
val doubleCols = Seq("affairs", "age", "yearsmarried", "religiousness", "education", "occupation", "rating")
val stringCols = Seq("gender", "children")
val data1 = data.select(
  (doubleCols.map(c => data(c).cast("Double")) ++ stringCols.map(c => data(c).cast("String"))): _*)
data1: org.apache.spark.sql.DataFrame = [affairs: double, age: double ... 7 more fields] 
   
// Print the schema after casting to confirm the Double/String column types.
data1.printSchema() 
root 
 |-- affairs: double (nullable = true) 
 |-- age: double (nullable = true) 
 |-- yearsmarried: double (nullable = true) 
 |-- religiousness: double (nullable = true) 
 |-- education: double (nullable = true) 
 |-- occupation: double (nullable = true) 
 |-- rating: double (nullable = true) 
 |-- gender: string (nullable = true) 
 |-- children: string (nullable = true) 
   
   
// Preview the first 10 rows of the typed DataFrame.
data1.limit(10).show 
+-------+----+------------+-------------+---------+----------+------+------+--------+ 
|affairs| age|yearsmarried|religiousness|education|occupation|rating|gender|children| 
+-------+----+------------+-------------+---------+----------+------+------+--------+ 
|    0.0|37.0|        10.0|          3.0|     18.0|       7.0|   4.0|  male|      no| 
|    0.0|27.0|         4.0|          4.0|     14.0|       6.0|   4.0|female|      no| 
|    0.0|32.0|        15.0|          1.0|     12.0|       1.0|   4.0|female|     yes| 
|    0.0|57.0|        15.0|          5.0|     18.0|       6.0|   5.0|  male|     yes| 
|    0.0|22.0|        0.75|          2.0|     17.0|       6.0|   3.0|  male|      no| 
|    0.0|32.0|         1.5|          2.0|     17.0|       5.0|   5.0|female|      no| 
|    0.0|22.0|        0.75|          2.0|     12.0|       1.0|   3.0|female|      no| 
|    0.0|57.0|        15.0|          2.0|     14.0|       4.0|   4.0|  male|     yes| 
|    0.0|32.0|        15.0|          4.0|     16.0|       1.0|   2.0|female|     yes| 
|    0.0|22.0|         1.5|          4.0|     14.0|       4.0|   5.0|  male|      no| 
+-------+----+------------+-------------+---------+----------+------+------+--------+ 
   
// Alias for the typed DataFrame; the encoding steps below read from dataDF.
val dataDF = data1
dataDF: org.apache.spark.sql.DataFrame = [affairs: double, age: double ... 7 more fields] 
   
// Mark dataDF for caching: it is both fitted on and transformed by the gender StringIndexer.
dataDF.cache() 
res4: dataDF.type = [affairs: double, age: double ... 7 more fields] 

 

3. 將字串欄位轉換成數值索引，再進行 OneHot 編碼。注意 setDropLast 要設為 false：預設會捨棄最後一個類別，設為 false 才能讓每個類別都保有自己的維度。

// Convert gender strings to numeric indices: fit a StringIndexer on dataDF that
// writes the index of each gender value into the new genderIndex column.
val indexer = new StringIndexer().setInputCol("gender").setOutputCol("genderIndex").fit(dataDF) 
indexer: org.apache.spark.ml.feature.StringIndexerModel = strIdx_27dba613193a 
   
// Append genderIndex (the numeric index of gender) to the data.
val indexed = indexer.transform(dataDF) 
indexed: org.apache.spark.sql.DataFrame = [affairs: double, age: double ... 8 more fields] 
   
// One-hot encode genderIndex into genderVec. setDropLast(false) keeps a slot for every
// category, so both genders get their own position in a 2-element vector
// (see the (2,[0],[1.0]) / (2,[1],[1.0]) values in the output below).
val encoder = new OneHotEncoder().setInputCol("genderIndex").setOutputCol("genderVec").setDropLast(false) 
encoder: org.apache.spark.ml.feature.OneHotEncoder = oneHot_155a53de3aef 
   
val encoded = encoder.transform(indexed) 
encoded: org.apache.spark.sql.DataFrame = [affairs: double, age: double ... 9 more fields] 
   
// Show the first 20 rows with the new genderIndex / genderVec columns.
encoded.show() 
+-------+----+------------+-------------+---------+----------+------+------+--------+-----------+-------------+ 
|affairs| age|yearsmarried|religiousness|education|occupation|rating|gender|children|genderIndex|    genderVec| 
+-------+----+------------+-------------+---------+----------+------+------+--------+-----------+-------------+ 
|    0.0|37.0|        10.0|          3.0|     18.0|       7.0|   4.0|  male|      no|        1.0|(2,[1],[1.0])| 
|    0.0|27.0|         4.0|          4.0|     14.0|       6.0|   4.0|female|      no|        0.0|(2,[0],[1.0])| 
|    0.0|32.0|        15.0|          1.0|     12.0|       1.0|   4.0|female|     yes|        0.0|(2,[0],[1.0])| 
|    0.0|57.0|        15.0|          5.0|     18.0|       6.0|   5.0|  male|     yes|        1.0|(2,[1],[1.0])| 
|    0.0|22.0|        0.75|          2.0|     17.0|       6.0|   3.0|  male|      no|        1.0|(2,[1],[1.0])| 
|    0.0|32.0|         1.5|          2.0|     17.0|       5.0|   5.0|female|      no|        0.0|(2,[0],[1.0])| 
|    0.0|22.0|        0.75|          2.0|     12.0|       1.0|   3.0|female|      no|        0.0|(2,[0],[1.0])| 
|    0.0|57.0|        15.0|          2.0|     14.0|       4.0|   4.0|  male|     yes|        1.0|(2,[1],[1.0])| 
|    0.0|32.0|        15.0|          4.0|     16.0|       1.0|   2.0|female|     yes|        0.0|(2,[0],[1.0])| 
|    0.0|22.0|         1.5|          4.0|     14.0|       4.0|   5.0|  male|      no|        1.0|(2,[1],[1.0])| 
|    0.0|37.0|        15.0|          2.0|     20.0|       7.0|   2.0|  male|     yes|        1.0|(2,[1],[1.0])| 
|    0.0|27.0|         4.0|          4.0|     18.0|       6.0|   4.0|  male|     yes|        1.0|(2,[1],[1.0])| 
|    0.0|47.0|        15.0|          5.0|     17.0|       6.0|   4.0|  male|     yes|        1.0|(2,[1],[1.0])| 
|    0.0|22.0|         1.5|          2.0|     17.0|       5.0|   4.0|female|      no|        0.0|(2,[0],[1.0])| 
|    0.0|27.0|         4.0|          4.0|     14.0|       5.0|   4.0|female|      no|        0.0|(2,[0],[1.0])| 
|    0.0|37.0|        15.0|          1.0|     17.0|       5.0|   5.0|female|     yes|        0.0|(2,[0],[1.0])| 
|    0.0|37.0|        15.0|          2.0|     18.0|       4.0|   3.0|female|     yes|        0.0|(2,[0],[1.0])| 
|    0.0|22.0|        0.75|          3.0|     16.0|       5.0|   4.0|female|      no|        0.0|(2,[0],[1.0])| 
|    0.0|22.0|         1.5|          2.0|     16.0|       5.0|   5.0|female|      no|        0.0|(2,[0],[1.0])| 
|    0.0|27.0|        10.0|          2.0|     14.0|       1.0|   5.0|female|     yes|        0.0|(2,[0],[1.0])| 
+-------+----+------------+-------------+---------+----------+------+------+--------+-----------+-------------+ 
only showing top 20 rows 
  
// Repeat the index + one-hot steps for the children column (childrenIndex / childrenVec),
// starting from the gender-encoded DataFrame.
val indexer1 = new StringIndexer().setInputCol("children").setOutputCol("childrenIndex").fit(encoded) 
indexer1: org.apache.spark.ml.feature.StringIndexerModel = strIdx_55db099c07b7
   
val indexed1 = indexer1.transform(encoded) 
indexed1: org.apache.spark.sql.DataFrame = [affairs: double, age: double ... 10 more fields] 
   
// dropLast(false) again keeps one slot per category (2-element vectors for yes/no).
val encoder1 = new OneHotEncoder().setInputCol("childrenIndex").setOutputCol("childrenVec").setDropLast(false) 
   
val encoded1 = encoder1.transform(indexed1) 
encoded1: org.apache.spark.sql.DataFrame = [affairs: double, age: double ... 11 more fields] 
   
encoded1.show() 
+-------+----+------------+-------------+---------+----------+------+------+--------+-----------+-------------+-------------+-------------+ 
|affairs| age|yearsmarried|religiousness|education|occupation|rating|gender|children|genderIndex|    genderVec|childrenIndex|  childrenVec| 
+-------+----+------------+-------------+---------+----------+------+------+--------+-----------+-------------+-------------+-------------+ 
|    0.0|37.0|        10.0|          3.0|     18.0|       7.0|   4.0|  male|      no|        1.0|(2,[1],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|27.0|         4.0|          4.0|     14.0|       6.0|   4.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|32.0|        15.0|          1.0|     12.0|       1.0|   4.0|female|     yes|        0.0|(2,[0],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|57.0|        15.0|          5.0|     18.0|       6.0|   5.0|  male|     yes|        1.0|(2,[1],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|22.0|        0.75|          2.0|     17.0|       6.0|   3.0|  male|      no|        1.0|(2,[1],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|32.0|         1.5|          2.0|     17.0|       5.0|   5.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|22.0|        0.75|          2.0|     12.0|       1.0|   3.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|57.0|        15.0|          2.0|     14.0|       4.0|   4.0|  male|     yes|        1.0|(2,[1],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|32.0|        15.0|          4.0|     16.0|       1.0|   2.0|female|     yes|        0.0|(2,[0],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|22.0|         1.5|          4.0|     14.0|       4.0|   5.0|  male|      no|        1.0|(2,[1],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|37.0|        15.0|          2.0|     20.0|       7.0|   2.0|  male|     yes|        1.0|(2,[1],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|27.0|         4.0|          4.0|     18.0|       6.0|   4.0|  male|     yes|        1.0|(2,[1],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|47.0|        15.0|          5.0|     17.0|       6.0|   4.0|  male|     yes|        1.0|(2,[1],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|22.0|         1.5|          2.0|     17.0|       5.0|   4.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|27.0|         4.0|          4.0|     14.0|       5.0|   4.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|37.0|        15.0|          1.0|     17.0|       5.0|   5.0|female|     yes|        0.0|(2,[0],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|37.0|        15.0|          2.0|     18.0|       4.0|   3.0|female|     yes|        0.0|(2,[0],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|22.0|        0.75|          3.0|     16.0|       5.0|   4.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|22.0|         1.5|          2.0|     16.0|       5.0|   5.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|27.0|        10.0|          2.0|     14.0|       1.0|   5.0|female|     yes|        0.0|(2,[0],[1.0])|          0.0|(2,[0],[1.0])| 
+-------+----+------------+-------------+---------+----------+------+------+--------+-----------+-------------+-------------+-------------+ 
only showing top 20 rows 
  
   
// Name the fully encoded DataFrame: typed columns plus the index and one-hot vector columns.
val encodeDF: DataFrame = encoded1
encodeDF: org.apache.spark.sql.DataFrame = [affairs: double, age: double ... 11 more fields] 
   
encodeDF.show() 
+-------+----+------------+-------------+---------+----------+------+------+--------+-----------+-------------+-------------+-------------+ 
|affairs| age|yearsmarried|religiousness|education|occupation|rating|gender|children|genderIndex|    genderVec|childrenIndex|  childrenVec| 
+-------+----+------------+-------------+---------+----------+------+------+--------+-----------+-------------+-------------+-------------+ 
|    0.0|37.0|        10.0|          3.0|     18.0|       7.0|   4.0|  male|      no|        1.0|(2,[1],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|27.0|         4.0|          4.0|     14.0|       6.0|   4.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|32.0|        15.0|          1.0|     12.0|       1.0|   4.0|female|     yes|        0.0|(2,[0],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|57.0|        15.0|          5.0|     18.0|       6.0|   5.0|  male|     yes|        1.0|(2,[1],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|22.0|        0.75|          2.0|     17.0|       6.0|   3.0|  male|      no|        1.0|(2,[1],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|32.0|         1.5|          2.0|     17.0|       5.0|   5.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|22.0|        0.75|          2.0|     12.0|       1.0|   3.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|57.0|        15.0|          2.0|     14.0|       4.0|   4.0|  male|     yes|        1.0|(2,[1],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|32.0|        15.0|          4.0|     16.0|       1.0|   2.0|female|     yes|        0.0|(2,[0],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|22.0|         1.5|          4.0|     14.0|       4.0|   5.0|  male|      no|        1.0|(2,[1],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|37.0|        15.0|          2.0|     20.0|       7.0|   2.0|  male|     yes|        1.0|(2,[1],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|27.0|         4.0|          4.0|     18.0|       6.0|   4.0|  male|     yes|        1.0|(2,[1],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|47.0|        15.0|          5.0|     17.0|       6.0|   4.0|  male|     yes|        1.0|(2,[1],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|22.0|         1.5|          2.0|     17.0|       5.0|   4.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|27.0|         4.0|          4.0|     14.0|       5.0|   4.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|37.0|        15.0|          1.0|     17.0|       5.0|   5.0|female|     yes|        0.0|(2,[0],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|37.0|        15.0|          2.0|     18.0|       4.0|   3.0|female|     yes|        0.0|(2,[0],[1.0])|          0.0|(2,[0],[1.0])| 
|    0.0|22.0|        0.75|          3.0|     16.0|       5.0|   4.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|22.0|         1.5|          2.0|     16.0|       5.0|   5.0|female|      no|        0.0|(2,[0],[1.0])|          1.0|(2,[1],[1.0])| 
|    0.0|27.0|        10.0|          2.0|     14.0|       1.0|   5.0|female|     yes|        0.0|(2,[0],[1.0])|          0.0|(2,[0],[1.0])| 
+-------+----+------------+-------------+---------+----------+------+------+--------+-----------+-------------+-------------+-------------+ 
only showing top 20 rows 
   
   
// Confirm that genderVec / childrenVec are vector-typed columns.
encodeDF.printSchema() 
root 
 |-- affairs: double (nullable = true) 
 |-- age: double (nullable = true) 
 |-- yearsmarried: double (nullable = true) 
 |-- religiousness: double (nullable = true) 
 |-- education: double (nullable = true) 
 |-- occupation: double (nullable = true) 
 |-- rating: double (nullable = true) 
 |-- gender: string (nullable = true) 
 |-- children: string (nullable = true) 
 |-- genderIndex: double (nullable = true) 
 |-- genderVec: vector (nullable = true) 
 |-- childrenIndex: double (nullable = true) 
 |-- childrenVec: vector (nullable = true) 

 

4. 將欄位組合成特徵向量欄位 features

// Assemble the 7 numeric columns and the two 2-element one-hot vectors into a single
// "features" vector column (7 + 2 + 2 = 11 entries per row).
val assembler = new VectorAssembler().setInputCols(Array("affairs", "age", "yearsmarried", "religiousness", "education", "occupation", "rating", "genderVec", "childrenVec")).setOutputCol("features") 
assembler: org.apache.spark.ml.feature.VectorAssembler = vecAssembler_df76d5d1e3f4
   
val vecDF: DataFrame = assembler.transform(encodeDF) 
vecDF: org.apache.spark.sql.DataFrame = [affairs: double, age: double ... 12 more fields] 
   
vecDF.select("features").show 
+--------------------+ 
|            features| 
+--------------------+ 
|[0.0,37.0,10.0,3....| 
|[0.0,27.0,4.0,4.0...| 
|[0.0,32.0,15.0,1....| 
|[0.0,57.0,15.0,5....| 
|[0.0,22.0,0.75,2....| 
|[0.0,32.0,1.5,2.0...| 
|[0.0,22.0,0.75,2....| 
|[0.0,57.0,15.0,2....| 
|[0.0,32.0,15.0,4....| 
|[0.0,22.0,1.5,4.0...| 
|[0.0,37.0,15.0,2....| 
|[0.0,27.0,4.0,4.0...| 
|[0.0,47.0,15.0,5....| 
|[0.0,22.0,1.5,2.0...| 
|[0.0,27.0,4.0,4.0...| 
|[0.0,37.0,15.0,1....| 
|[0.0,37.0,15.0,2....| 
|[0.0,22.0,0.75,3....| 
|[0.0,22.0,1.5,2.0...| 
|[0.0,27.0,10.0,2....| 
+--------------------+ 
only showing top 20 rows 

 

5. 標準化：減去均值並除以標準差

// Standardize: setWithMean(true) subtracts each feature's mean and setWithStd(true)
// divides it by its standard deviation.
val scaler = new StandardScaler().setInputCol("features").setOutputCol("scaledFeatures").setWithStd(true).setWithMean(true) 
scaler: org.apache.spark.ml.feature.StandardScaler = stdScal_43d3da1cd3bf 
   
// Compute summary statistics by fitting the StandardScaler. 
val scalerModel = scaler.fit(vecDF) 
scalerModel: org.apache.spark.ml.feature.StandardScalerModel = stdScal_43d3da1cd3bf 
   
// Center each feature to zero mean and scale it to unit standard deviation.
val scaledData: DataFrame = scalerModel.transform(vecDF) 
scaledData: org.apache.spark.sql.DataFrame = [affairs: double, age: double ... 13 more fields] 
   
scaledData.select("features", "scaledFeatures").show 
+--------------------+--------------------+ 
|            features|      scaledFeatures| 
+--------------------+--------------------+ 
|[0.0,37.0,10.0,3....|[-0.4413500298573...| 
|[0.0,27.0,4.0,4.0...|[-0.4413500298573...| 
|[0.0,32.0,15.0,1....|[-0.4413500298573...| 
|[0.0,57.0,15.0,5....|[-0.4413500298573...| 
|[0.0,22.0,0.75,2....|[-0.4413500298573...| 
|[0.0,32.0,1.5,2.0...|[-0.4413500298573...| 
|[0.0,22.0,0.75,2....|[-0.4413500298573...| 
|[0.0,57.0,15.0,2....|[-0.4413500298573...| 
|[0.0,32.0,15.0,4....|[-0.4413500298573...| 
|[0.0,22.0,1.5,4.0...|[-0.4413500298573...| 
|[0.0,37.0,15.0,2....|[-0.4413500298573...| 
|[0.0,27.0,4.0,4.0...|[-0.4413500298573...| 
|[0.0,47.0,15.0,5....|[-0.4413500298573...| 
|[0.0,22.0,1.5,2.0...|[-0.4413500298573...| 
|[0.0,27.0,4.0,4.0...|[-0.4413500298573...| 
|[0.0,37.0,15.0,1....|[-0.4413500298573...| 
|[0.0,37.0,15.0,2....|[-0.4413500298573...| 
|[0.0,22.0,0.75,3....|[-0.4413500298573...| 
|[0.0,22.0,1.5,2.0...|[-0.4413500298573...| 
|[0.0,27.0,10.0,2....|[-0.4413500298573...| 
+--------------------+--------------------+ 
only showing top 20 rows 

 

6. 主成分分析（PCA）

// Principal components: fit a PCA model on the standardized features, keeping k = 3 components.
val pca = new PCA().setInputCol("scaledFeatures").setOutputCol("pcaFeatures").setK(3).fit(scaledData) 
   
pca.explainedVariance.values // proportion of the total variance explained by each of the 3 components
res11: Array[Double] = Array(0.28779526464781313, 0.23798543640278289, 0.11742828783633019) 
   
// Principal-component matrix: one row per input feature (11) and one column per component (3).
// Entries are eigenvector coefficients, not feature-component correlations; for standardized
// inputs, scale a column by the square root of that component's eigenvalue to get correlation loadings.
pca.pc 
res12: org.apache.spark.ml.linalg.DenseMatrix =
-0.12034310848156521  0.05153952289637974   0.6678769450480689
-0.42860623714516627  0.05417889891307473   -0.05592377098140197
-0.44404074412877986  0.1926596811059294    -0.017025575192258197
-0.12233707317255231  0.08053139375662526   -0.5093149296300096
-0.14664751606128462  -0.3872166556211308   -0.03406819489501708
-0.145543746024348    -0.43054860653839705  0.07841454709046872
0.17703994181974803   -0.12792784984216296  -0.5173229755329072
0.2459668445061567    0.4915809641798787    0.010477548320795945
-0.2459668445061567   -0.4915809641798787   -0.010477548320795945
-0.44420980045271047  0.240652448514566     -0.089356723885704
0.4442098004527103    -0.24065244851456588  0.08935672388570405
   
// Inspect the fitted model's parameter values (inputCol, k, outputCol).
pca.extractParamMap() 
res13: org.apache.spark.ml.param.ParamMap =
{ 
    pca_40a453a54776-inputCol: scaledFeatures, 
    pca_40a453a54776-k: 3, 
    pca_40a453a54776-outputCol: pcaFeatures 
} 
   
// List the parameters the model defines.
pca.params 
res14: Array[org.apache.spark.ml.param.Param[_]] = Array(pca_40a453a54776__inputCol, pca_40a453a54776__k, pca_40a453a54776__outputCol) 
   
  
   
// Project the standardized features onto the 3 components, adding the pcaFeatures column.
val pcaDF: DataFrame = pca.transform(scaledData) 
pcaDF: org.apache.spark.sql.DataFrame = [affairs: double, age: double ... 14 more fields] 
   
// Mark pcaDF for caching.
pcaDF.cache() 
res15: pcaDF.type = [affairs: double, age: double ... 14 more fields] 
   
   
pcaDF.printSchema() 
root 
 |-- affairs: double (nullable = true) 
 |-- age: double (nullable = true) 
 |-- yearsmarried: double (nullable = true) 
 |-- religiousness: double (nullable = true) 
 |-- education: double (nullable = true) 
 |-- occupation: double (nullable = true) 
 |-- rating: double (nullable = true) 
 |-- gender: string (nullable = true) 
 |-- children: string (nullable = true) 
 |-- genderIndex: double (nullable = true) 
 |-- genderVec: vector (nullable = true) 
 |-- childrenIndex: double (nullable = true) 
 |-- childrenVec: vector (nullable = true) 
 |-- features: vector (nullable = true) 
 |-- scaledFeatures: vector (nullable = true) 
 |-- pcaFeatures: vector (nullable = true) 
   
   
// Compare the raw, standardized and PCA-projected vectors side by side.
pcaDF.select("features", "scaledFeatures", "pcaFeatures").show 
+--------------------+--------------------+--------------------+ 
|            features|      scaledFeatures|         pcaFeatures| 
+--------------------+--------------------+--------------------+ 
|[0.0,37.0,10.0,3....|[-0.4413500298573...|[0.27828160409293...| 
|[0.0,27.0,4.0,4.0...|[-0.4413500298573...|[2.42147114101165...| 
|[0.0,32.0,15.0,1....|[-0.4413500298573...|[0.18301418047489...| 
|[0.0,57.0,15.0,5....|[-0.4413500298573...|[-2.9795960667914...| 
|[0.0,22.0,0.75,2....|[-0.4413500298573...|[1.79299133565688...| 
|[0.0,32.0,1.5,2.0...|[-0.4413500298573...|[2.65694237441759...| 
|[0.0,22.0,0.75,2....|[-0.4413500298573...|[3.48234503794570...| 
|[0.0,57.0,15.0,2....|[-0.4413500298573...|[-2.4215838062079...| 
|[0.0,32.0,15.0,4....|[-0.4413500298573...|[-0.6964555195741...| 
|[0.0,22.0,1.5,4.0...|[-0.4413500298573...|[2.18771069800414...| 
|[0.0,37.0,15.0,2....|[-0.4413500298573...|[-2.4259075891377...| 
|[0.0,27.0,4.0,4.0...|[-0.4413500298573...|[-0.7743038356008...| 
|[0.0,47.0,15.0,5....|[-0.4413500298573...|[-2.6176149267534...| 
|[0.0,22.0,1.5,2.0...|[-0.4413500298573...|[2.95788535193022...| 
|[0.0,27.0,4.0,4.0...|[-0.4413500298573...|[2.50146472861263...| 
|[0.0,37.0,15.0,1....|[-0.4413500298573...|[-0.5123817022008...| 
|[0.0,37.0,15.0,2....|[-0.4413500298573...|[-0.9191740114044...| 
|[0.0,22.0,0.75,3....|[-0.4413500298573...|[2.97391491782863...| 
|[0.0,22.0,1.5,2.0...|[-0.4413500298573...|[3.17940505267806...| 
|[0.0,27.0,10.0,2....|[-0.4413500298573...|[0.74585406839527...| 
+--------------------+--------------------+--------------------+ 
only showing top 20 rows 

 

7. KMeans 聚類：比較不同 K 值的 SSE 以選擇 K

// Sweep k = 2..20 and record the within-set sum of squared errors (SSE) of each
// k-means model, so that K can be chosen from the SSE "elbow".
// NOTE(review): the original note also mentioned the silhouette coefficient, but
// nothing below computes it.

// Cluster on the PCA projection so the pipeline matches the post
// (one-hot -> standardize -> PCA -> cluster). To cluster on the full standardized
// feature vector instead, set these to scaledData / "scaledFeatures".
val clusterInput: DataFrame = pcaDF
val clusterFeaturesCol: String = "pcaFeatures"

val KSSE = (2 to 20).toList.map { k =>
  // Train a k-means model; the fixed seed makes the initialization repeatable.
  val kmeans = new KMeans().setK(k).setSeed(1L).setFeaturesCol(clusterFeaturesCol)
  val model = kmeans.fit(clusterInput)

  // Within Set Sum of Squared Errors on the training data.
  // NOTE(review): computeCost is deprecated from Spark 2.4 and removed in 3.0;
  // on those versions use model.summary.trainingCost instead.
  val WSSSE = model.computeCost(clusterInput)

  // (k, configured MAXIMUM iterations -- not the number actually run --, SSE,
  //  per-row cluster assignments, record count per cluster, cluster centers)
  (k, model.getMaxIter, WSSSE, model.summary.cluster, model.summary.clusterSizes, model.clusterCenters)
}

// One row per k, used to pick K from the SSE curve.
val KSSEdf: DataFrame = KSSE
  .map { case (k, maxIter, sse, _, sizes, _) => (k, maxIter, sse, sizes) }
  .toDF("K", "MaxIter", "SSE", "clusterSizes")
KSSEdf.show(false)

// Print one readable line per k; println on the raw tuple would only show the cluster
// DataFrame's schema and JVM hash codes for the clusterSizes / clusterCenters arrays.
KSSE.foreach { case (k, maxIter, sse, _, sizes, centers) =>
  println(s"k=$k maxIter=$maxIter SSE=$sse clusterSizes=${sizes.mkString("[", ", ", "]")} " +
    s"centers=${centers.mkString("; ")}")
}

 

相關文章