Spark MLlib SVM 文字分類器實現

破棉襖發表於2015-12-30

好久沒寫部落格了,最近搞了一個文字分類器,在此記錄一下:


簡介:

支援向量機,因其英文名為support vector machine,故一般簡稱SVM,通俗來講,它是一種二類分類模型,其基本模型定義為特徵空間上的間隔最大的線性分類器,其學習策略便是間隔最大化,最終可轉化為一個凸二次規劃問題的求解。

1  “機” —— Classification Machine,分類器

2  “支援向量” —— 他們就是離分界線最近的向量。也就是說分介面是靠這些向量確定的,他們支撐著分類面。名字就是這麼來的...(就是離最優分類平面最近的離散點,也可以稱為向量) 


spark自帶了一個svm實現的dome,該dome直接讀取儲存libsvm所需稀疏向量的檔案,但是並未提供向量化方法,需自己呼叫HashingTF、IDF轉換為稀疏向量

程式碼:

  1. /**
  2.  * SVM分類物件
  3.  * @author wangzengxu
  4.  */
  5. object SVM{

  6.     def main(args: Array[String]){
  7.           
  8.      val Array(
  9.          rightPath,      // 正面訓練集路徑
  10.          negativePath,   // 負面訓練集路徑
  11.          waitData,       // 待分類資料存放路徑
  12.          vectorsLocl,    // 向量存放路徑
  13.          iterativeNum    // 迭代次數
  14.          ) = args 
  15.      
  16.      var sparkconf = new SparkConf().setAppName("wzx_svm_classificationsV2")
  17.      var sc = new SparkContext(sparkconf)
  18.      
  19.      sc.addJar("/usr/wzx/spark/svm/SVM_WZX_lib/IKAnalyzer2012_u6.jar");
  20.      sc.addJar("/usr/wzx/spark/svm/SVM_WZX_lib/lucene-analyzers-common-4.3.0.jar");
  21.      sc.addJar("/usr/wzx/spark/svm/SVM_WZX_lib/lucene-core-4.3.0.jar");
  22.      sc.addJar("/usr/wzx/spark/svm/SVM_WZX_lib/lucene-queryparser-4.3.0.jar");
  23.      
  24.      val train_vectors_local = vectorsLocl+"/train-"+DateUtils.getNowDate()  //訓練向量存放目錄 
  25.      val wait_vectors_local = vectorsLocl+"/wait-"+DateUtils.getNowDate()    //待分向量存放目錄 
  26.      
  27.      val data_path_right = rightPath           //正面訓練集文章路徑檔案 每行一篇
  28.      val data_path_negative = negativePath     //負面訓練集文章路徑檔案 每行一篇
  29.      val data_path_wait = waitData             //待分資料存放路徑 
  30.      val iterative_number = iterativeNum.toInt //訓練模型迭代次數
  31.      
  32.      /***********************start 分詞******************************************/
  33.      
  34.        val right_data = sc.textFile(data_path_right)
  35.        
  36.        val negative_data = sc.textFile(data_path_negative)
  37.        
  38.        val wait_data = sc.textFile(data_path_wait)
  39.        
  40.        //去停用詞 
  41.        
  42.        val right_text = right_data.map { x =>
  43.             val str = IKUtils.participle(x)
  44.             (1,str) //正面1
  45.         } 
  46.         
  47.         val negative_text = negative_data.map { x =>
  48.             val str = IKUtils.participle(x)
  49.             (0,str) //負面0
  50.         }
  51.         
  52.         val wait_text = wait_data.map { x =>
  53.             val str = IKUtils.participle(x)
  54.             (2,str) //待分2
  55.         }
  56.         
  57.         val data_all_train = right_text.++(negative_text) //訓練集RDD合併
  58.       
  59.      /***********************end 分詞******************************************/
  60.  
  61.         
  62.         
  63.     /***********************start 向量化******************************************/
  64.         
  65.        val hashingTF = new HashingTF(Math.pow(2, 18).toInt)
  66.         
  67.         //訓練集TF向量化
  68.         val documents_train = data_all_train.map{
  69.           case(num,str) =>
  70.             (num,str.split(" ").toSeq)
  71.         }
  72.        
  73.         val tf_num_pairs_train = documents_train.map {
  74.         case (num,seq) =>
  75.           val tf = hashingTF.transform(seq)
  76.           (num,tf)
  77.         }
  78.         
  79.         //待分類TF向量化
  80.         val documents_wait = wait_text.map{
  81.           case(num,str) =>
  82.             (num,str.split(" ").toSeq)
  83.         }
  84.            
  85.         val tf_num_pairs_wait = documents_wait.map {
  86.           case (num,seq) =>
  87.             val tf = hashingTF.transform(seq)
  88.             (num,tf)
  89.         }
  90.       
  91.         tf_num_pairs_train.cache()
  92.         tf_num_pairs_wait.cache()
  93.      
  94.       
  95.        //利用訓練集TF構建IDF MODEL
  96.        val idf = new IDF().fit(tf_num_pairs_train.values)
  97.       
  98.      
  99.       //將訓練集tf向量轉換成tf-idf向量
  100.       val num_idf_pairs_train = tf_num_pairs_train.mapValues(=> idf.transform(v)) 
  101.       //將待分類資料集tf向量轉換成tf-idf向量
  102.       val num_idf_pairs_wait = tf_num_pairs_wait.mapValues(=> idf.transform(v)) 
  103.       
  104.       //格式轉換 
  105.       val trainCollection = num_idf_pairs_train.map{
  106.         case(num,vector) => 
  107.            val vectorStr = num +" "+VectorToStr.change(vector)
  108.            vectorStr
  109.       } 
  110.         
  111.       val waitCollection = num_idf_pairs_wait.map{
  112.         case(num,vector) => 
  113.            val vectorStr = num +" "+VectorToStr.change(vector)
  114.            vectorStr
  115.       } 
  116.       
  117.       //落地 (後期可參看MLUtils原始碼來直接轉換為LabeledPoint避免落地)
  118.       trainCollection.coalesce(1).saveAsTextFile(train_vectors_local)
  119.       waitCollection.coalesce(1).saveAsTextFile(wait_vectors_local)
  120.       
  121.     /***********************end 向量化******************************************/
  122.         
  123.     /***********************start SVM模型訓練******************************************/
  124.     
  125.         val vectors_train = MLUtils.loadLibSVMFile(sc,train_vectors_local).cache()
  126.         val vectors_wait = MLUtils.loadLibSVMFile(sc,wait_vectors_local).cache()
  127.      
  128.        
  129.         //1 新建SVM模型,並設定訓練引數 
  130.         
  131.         val numIterations = iterative_number    //迭代次數,並非越大越好,需根據訓練集不斷調整來確定該值
  132.         
  133.         val stepSize = 1 
  134.         
  135.         val miniBatchFraction = 1.0             //步長
  136.         
  137.         val model = SVMWithSGD.train(vectors_train, numIterations, stepSize, miniBatchFraction) 
   
    
  1.      /***********************start SVM模型訓練******************************************/
  2.         
  3.         
  4.      /***********************start 分類******************************************/
  5.         
  6.         //4 對待分類資料向量進行分類 
  7.       
  8.         println("---------------訓練完成------------------------")
  9.         
  10.         val prediction_wait = model.predict(vectors_wait.map(_.features))
  11.         
  12.         println("---------------分類完成------------------------")
  13.       
  14.         prediction_wait.saveAsTextFile("/user/wzx/cs1")
  15.       
  16.         println("---------------儲存完成------------------------")
  17.      
  18.      /***********************end 分類******************************************/ 
  19.     
  20.       
  21.    }
  22. }


  1. /**
  2.  * IK分詞 去掉停用詞處理
  3.  * @author wangzengxu
  4.  *
  5.  */
  6. public class IKUtils {
  7.     
  8.      
  9.     
  10.      public static String participle(String text){
  11.      StringBuffer result = new StringBuffer();
  12.      //讀入停用詞檔案
  13.      BufferedReader StopWordFileBr = new BufferedReader(new InputStreamReader(IKUtils.class.getResourceAsStream("/stopword.dic")));    //注意jar包路徑問題
  14.      //用來存放停用詞的集合
  15.      Set<String> stopWordSet = new HashSet<String>();
  16.      //初如化停用詞集
  17.      String stopWord = null;
  18.      try {
  19.             for(; (stopWord = StopWordFileBr.readLine()) != null;){
  20.              stopWordSet.add(stopWord);
  21.              }
  22.         } catch (IOException e) {
  23.             e.printStackTrace();
  24.         }
  25.      //建立分詞物件
  26.      StringReader sr=new StringReader(text);
  27.      IKSegmenter ik=new IKSegmenter(sr, false);
  28.      Lexeme lex=null;
  29.      //分詞
  30.      try {
  31.             while((lex=ik.next())!=null){
  32.              //去除停用詞
  33.              if(stopWordSet.contains(lex.getLexemeText())) {
  34.              continue;
  35.              }
  36.              result.append(lex.getLexemeText()+" ");
  37.              }
  38.         } catch (IOException e) {
  39.             e.printStackTrace();
  40.         }
  41.      //關閉流
  42.      try {
  43.             StopWordFileBr.close();
  44.         } catch (IOException e) {
  45.             e.printStackTrace();
  46.         }
  47.      return result.toString();
  48.      }
  49.      
  50. }


dome中提供了評分程式碼,在模型訓練時需要根據評分來不斷調整迭代次數等來達到滿意的精度。當然,這個dome還有很多最佳化空間



來自 “ ITPUB部落格 ” ,連結:http://blog.itpub.net/29754888/viewspace-1967758/,如需轉載,請註明出處,否則將追究法律責任。

相關文章