Spark Learning Notes: Handwritten Digit Recognition

Posted by weixin_33912246 on 2017-05-25
import org.apache.spark.mllib.classification.{LogisticRegressionWithLBFGS, NaiveBayes}
import org.apache.spark.mllib.linalg.Vectors
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.mllib.tree.{DecisionTree, RandomForest}
import org.apache.spark.mllib.tree.configuration.Algo
import org.apache.spark.mllib.tree.impurity.Entropy

/**
  * Created by common on 17-5-17.
  */

case class LabeledPic(
                       label: Int,
                       pic: List[Double] = List()
                     )

object DigitRecognizer {

  def main(args: Array[String]): Unit = {

    val conf = new SparkConf().setAppName("DigitRecognizer").setMaster("local")
    val sc = new SparkContext(conf)
    // Strip the header row first: sed 1d train.csv > train_noheader.csv
    val trainFile = "file:///media/common/工作/kaggle/DigitRecognizer/train_noheader.csv"
    val trainRawData = sc.textFile(trainFile)
    // Split each line on commas, producing an RDD of arrays
    val trainRecords = trainRawData.map(line => line.split(","))

    val trainData = trainRecords.map { r =>
      // the first column is the digit label; the remaining 784 columns are pixel values
      val label = r(0).toInt
      val features = r.slice(1, r.size).map(d => d.toDouble)
      LabeledPoint(label, Vectors.dense(features))
    }
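
    // The accuracy figures in the blocks below are computed on the training data itself,
    // so they are optimistic. A minimal sketch of a held-out split (80/20 and the seed
    // are arbitrary choices) would be:
    //    val Array(trainSplit, validSplit) = trainData.randomSplit(Array(0.8, 0.2), seed = 42L)
    // Caching is worthwhile either way, since every model below re-reads the same RDD.
    trainData.cache()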


    //    // Naive Bayes model
    //    val nbModel = NaiveBayes.train(trainData)
    //
    //    val nbTotalCorrect = trainData.map { point =>
    //      if (nbModel.predict(point.features) == point.label) 1 else 0
    //    }.sum
    //    val nbAccuracy = nbTotalCorrect / trainData.count
    //
    //    println("Naive Bayes training-set accuracy: " + nbAccuracy)
    //
    //    // Predict on the test data
    //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
    //    // Split each line on commas, producing an RDD of arrays
    //    val testRecords = testRawData.map(line => line.split(","))
    //
    //    val testData = testRecords.map { r =>
    //      val features = r.map(d => d.toDouble)
    //      Vectors.dense(features)
    //    }
    //    val predictions = nbModel.predict(testData).map(p => p.toInt)
    //    // Save the predictions
    //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict")
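
    //    // Optional sketch: mllib's MulticlassMetrics gives per-class precision/recall and a
    //    // confusion matrix, which says more about a 10-class problem than a single accuracy
    //    // number. Assumes the Naive Bayes block above is uncommented so nbModel is in scope.
    //    import org.apache.spark.mllib.evaluation.MulticlassMetrics
    //    val nbMetrics = new MulticlassMetrics(trainData.map(p => (nbModel.predict(p.features), p.label)))
    //    println("Confusion matrix:\n" + nbMetrics.confusionMatrix)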


    //    // Logistic regression model (trained with LBFGS)
    //    val lrModel = new LogisticRegressionWithLBFGS()
    //      .setNumClasses(10)
    //      .run(trainData)
    //
    //    val lrTotalCorrect = trainData.map { point =>
    //      if (lrModel.predict(point.features) == point.label) 1 else 0
    //    }.sum
    //    val lrAccuracy = lrTotalCorrect / trainData.count
    //
    //    println("Logistic regression training-set accuracy: " + lrAccuracy)
    //
    //    // Predict on the test data
    //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
    //    // Split each line on commas, producing an RDD of arrays
    //    val testRecords = testRawData.map(line => line.split(","))
    //
    //    val testData = testRecords.map { r =>
    //      val features = r.map(d => d.toDouble)
    //      Vectors.dense(features)
    //    }
    //    val predictions = lrModel.predict(testData).map(p => p.toInt)
    //    // Save the predictions
    //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict1")
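
    //    // Optional sketch: LBFGS-based logistic regression tends to converge better when the
    //    // 0-255 pixel values are standardized first; mllib's StandardScaler can do this
    //    // (the withMean/withStd settings here are just one reasonable choice).
    //    import org.apache.spark.mllib.feature.StandardScaler
    //    val scaler = new StandardScaler(withMean = true, withStd = true).fit(trainData.map(_.features))
    //    val scaledTrainData = trainData.map(p => LabeledPoint(p.label, scaler.transform(p.features)))
    //    // ...then run the model on scaledTrainData instead of trainData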


    //    // Decision tree model
    //    val maxTreeDepth = 10
    //    val numClass = 10
    //    val dtModel = DecisionTree.train(trainData, Algo.Classification, Entropy, maxTreeDepth, numClass)
    //
    //    val dtTotalCorrect = trainData.map { point =>
    //      if (dtModel.predict(point.features) == point.label) 1 else 0
    //    }.sum
    //    val dtAccuracy = dtTotalCorrect / trainData.count
    //
    //    println("Decision tree training-set accuracy: " + dtAccuracy)
    //
    //    // Predict on the test data
    //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
    //    // Split each line on commas, producing an RDD of arrays
    //    val testRecords = testRawData.map(line => line.split(","))
    //
    //    val testData = testRecords.map { r =>
    //      val features = r.map(d => d.toDouble)
    //      Vectors.dense(features)
    //    }
    //    val predictions = dtModel.predict(testData).map(p => p.toInt)
    //    // Save the predictions
    //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict2")
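
    //    // Optional sketch: the same tree can be built through the string-based trainClassifier
    //    // API, which also exposes maxBins and mirrors the RandomForest call below
    //    // (the parameter values here are just the ones used above).
    //    val dtModel2 = DecisionTree.trainClassifier(trainData, 10, Map[Int, Int](), "entropy", 10, 32)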


    //    // Random forest model
    //    val numClasses = 10  // digits 0-9
    //    val categoricalFeaturesInfo = Map[Int, Int]()
    //    val numTrees = 50
    //    val featureSubsetStrategy = "auto"
    //    val impurity = "gini"
    //    val maxDepth = 10
    //    val maxBins = 32
    //    val rtModel = RandomForest.trainClassifier(trainData, numClasses, categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity, maxDepth, maxBins)
    //
    //    val rtTotalCorrect = trainData.map { point =>
    //      if (rtModel.predict(point.features) == point.label) 1 else 0
    //    }.sum
    //    val rtAccuracy = rtTotalCorrect / trainData.count
    //
    //    println("Random forest training-set accuracy: " + rtAccuracy)
    //
    //    // Predict on the test data
    //    val testRawData = sc.textFile("file:///media/common/工作/kaggle/DigitRecognizer/test_noheader.csv")
    //    // Split each line on commas, producing an RDD of arrays
    //    val testRecords = testRawData.map(line => line.split(","))
    //
    //    val testData = testRecords.map { r =>
    //      val features = r.map(d => d.toDouble)
    //      Vectors.dense(features)
    //    }
    //    val predictions = rtModel.predict(testData).map(p => p.toInt)
    //    // Save the predictions
    //    predictions.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/test_predict")
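
    //    // Optional sketch: Kaggle's Digit Recognizer expects "ImageId,Label" rows with 1-based
    //    // ids, so the raw predictions are usually paired with an index before saving (the header
    //    // row still has to be prepended separately; assumes a predictions RDD from one of the
    //    // blocks above).
    //    val submission = predictions.zipWithIndex().map { case (label, idx) => s"${idx + 1},$label" }
    //    submission.coalesce(1).saveAsTextFile("file:///media/common/工作/kaggle/DigitRecognizer/submission")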


  }

}

 
