K-Means演算法的程式碼實現(Java)

呂建奎發表於2015-11-02
//package cn.edu.pku.ss.dm.cluster;

import java.io.BufferedReader;

import java.io.BufferedWriter;

import java.io.FileNotFoundException;

import java.io.FileReader;

import java.io.FileWriter;

import java.io.IOException;

import java.util.ArrayList;

 

//K-means演算法實現

 

public class KMeans {

    //聚類的數目

    final static int ClassCount = 3;

    //樣本數目(測試集)

    final static int InstanceNumber = 150;  

    //樣本屬性數目(測試)

    final static int FieldCount = 5;

    

    //設定異常點閾值引數(每一類初始的最小數目為InstanceNumber/ClassCount^t)

    final static double t = 2.0;

    //存放資料的矩陣

    private float[][] data;

    

    //每個類的均值中心

    private float[][] classData;

    

    //噪聲集合索引

    private ArrayList<Integer> noises;

    

    //存放每次變換結果的矩陣

    private ArrayList<ArrayList<Integer>> result;

    

    //建構函式,初始化

    public KMeans()

    {

   //最後一位用來儲存結果

   data = new float[InstanceNumber][FieldCount+1];

   classData = new float[ClassCount][FieldCount];

   result = new ArrayList<ArrayList<Integer>>(ClassCount);

   noises = new ArrayList<Integer>();

   

    }

    

 

   /**

    * 主函式入口

    * 測試集的檔名稱為“測試集.data”,其中有1000*57大小的資料

    * 每一行為一個樣本,有57個屬性

    * 主要分為兩個步驟

    * 1.讀取資料

    * 2.進行聚類

    * 最後統計執行時間和消耗的記憶體

    * @param args

    */

   public static void main(String[] args) {

      // TODO Auto-generated method stub

       long startTime = System.currentTimeMillis();

       KMeans cluster = new KMeans();

       //讀取資料

       cluster.readData("D:/test.txt");

       //聚類過程

       cluster.cluster();

       //輸出結果

       cluster.printResult("clusterResult.data");

       long endTime = System.currentTimeMillis();

       System.out.println("Total Time:"+ (endTime - startTime)/1000+"s");

       System.out.println("Memory Consuming:"+(float)(Runtime.getRuntime().totalMemory() -

          Runtime.getRuntime().freeMemory())/1000000 + "MB");

   }

        /*

         * 讀取測試集的資料

         * 

         * @param trainingFileName 測試集檔名

         */

   public void readData(String trainingFileName)

   {

       try

       {

      FileReader fr = new FileReader(trainingFileName);

      BufferedReader br = new BufferedReader(fr);

      //存放資料的臨時變數

      String lineData = null;

      String[] splitData = null;

      int line = 0;

      //按行讀取

      while(br.ready())

      {

          //得到原始的字串

          lineData = br.readLine();

          splitData = lineData.split(",");

          //轉化為資料

//        System.out.println("length:"+splitData.length);

          if(splitData.length>1)

          {

             for(int i = 0;i < splitData.length;i++)

             {

//              System.out.println(splitData[i]);

//              System.out.println(splitData[i].getClass());

                if(splitData[i].startsWith("Iris-setosa"))

                {

                   data[line][i] = (float) 1.0;

                }

                else if(splitData[i].startsWith("Iris-versicolor"))

                {

                   data[line][i] = (float) 2.0;

                }

                else if(splitData[i].startsWith("Iris-virginica"))

                {

                   data[line][i] = (float) 3.0;

                }

                else

                {   //將資料擷取之後放進陣列 

                   data[line][i] = Float.parseFloat(splitData[i]);

                }

             }

             line++;

          }

      }

      System.out.println(line);

       }catch(IOException e)

       {

      e.printStackTrace();

       }

   }

   /*

    * 聚類過程,主要分為兩步

    * 1.迴圈找初始點

    * 2.不斷調整直到分類不再發生變化

    */

   public void cluster()

   {

       //資料歸一化

       normalize();

       //標記是否需要重新找初始點

       boolean needUpdataInitials = true;

       

       //找初始點的迭代次數

       int times = 1;

       //找初始點

       while(needUpdataInitials)

       {

      needUpdataInitials = false;

      result.clear();

      System.out.println("Find Initials Iteration"+(times++)+"time(s)");

      

      //一次找初始點的嘗試和根據初始點的分類

      findInitials();

      firstClassify();

      

      //如果某個分類的數目小於特定的閾值,則認為這個分類中的所有樣本都是噪聲點

      //需要重新找初始點

      for(int i = 0;i < result.size();i++)

      {

          if(result.get(i).size() < InstanceNumber/Math.pow(ClassCount,t))

          {

         needUpdataInitials = true;

         noises.addAll(result.get(i));

          }

      }

       }

       

       //找到合適的初始點後

       //不斷的調整均值中心和分類,直到不再發生任何變化

       Adjust();

   }

   

   /*

    * 對資料進行歸一化

    * 1.找每一個屬性的最大值

    * 2.對某個樣本的每個屬性除以其最大值

    */

   public void normalize()

   {

       //找最大值

       float[] max = new float[FieldCount];

       for(int i = 0;i < InstanceNumber;i++)

       {

      for(int j = 0;j < FieldCount;j++)

      {

          if(data[i][j] > max[j])

         max[j] = data[i][j];

      }

       }

       

       //歸一化

       for(int i = 0;i < InstanceNumber;i++)

       {

      for(int j = 0;j < FieldCount;j++)

      {

          data[i][j] = data[i][j]/max[j];

      }

       }

   }

   

   //關於初始向量的一次找尋嘗試

   public void findInitials()

   {

       //a,b為標誌距離最遠的兩個向量的索引

       int i,j,a,b;

       i = j = a = b = 0;

       

       //最遠距離

       float maxDis = 0;

       

       //已經找到的初始點個數

       int alreadyCls = 2;

       

       //存放已經標記為初始點的向量索引

       ArrayList<Integer> initials = new ArrayList<Integer>();

       

       //從兩個開始

       for(;i < InstanceNumber;i++)

       {

      //噪聲點

      if(noises.contains(i))

          continue;

      //long startTime = System.currentTimeMillis();

      j = i + 1;

      for(;j < InstanceNumber;j++)

      {

          //噪聲點

          if(noises.contains(j))

         continue;

          //找出最大的距離並記錄下來

          float newDis = calDis(data[i],data[j]);

          if(maxDis < newDis)

          {

         a = i;

         b = j;

         maxDis = newDis;

          }

      }

      //long endTime = System.currentTimeMillis();

      //System.out.println(i + "Vector Caculation Time:"+(endTime-startTime)+"ms");

       }

       

       //將前兩個初始點記錄下來

       initials.add(a);

       initials.add(b);

       classData[0] = data[a];

       classData[1] = data[b];

       

       //在結果中新建存放某樣本索引的物件,並把初始點新增進去

       ArrayList<Integer> resultOne = new ArrayList<Integer>();

       ArrayList<Integer> resultTwo = new ArrayList<Integer>();

       resultOne.add(a);

       resultTwo.add(b);

       result.add(resultOne);

       result.add(resultTwo);

       

       //找到剩餘的幾個初始點

       while(alreadyCls < ClassCount)

       {

      i = j = 0;

      float maxMin = 0;

      int newClass = -1;

      

      //找最小值中的最大值

      for(;i < InstanceNumber;i++)

      {

          float min = 0;

          float newMin = 0;

          //找和已有類的最小值

          if(initials.contains(i))

         continue;

          //噪聲點去除

          if(noises.contains(i))

         continue;

          for(j = 0;j < alreadyCls;j++)

          {

         newMin = calDis(data[i],classData[j]);

         if(min == 0 || newMin < min)

             min = newMin;

          }

          

          //新最小距離較大

          if(min > maxMin)

          {

         maxMin = min;

         newClass = i;

          }

      }

      //新增到均值集合和結果集合中

      //System.out.println("NewClass"+newClass);

      initials.add(newClass);

      classData[alreadyCls++] = data[newClass];

      ArrayList<Integer> rslt = new ArrayList<Integer>();

      rslt.add(newClass);

      result.add(rslt);

       }

   }

   

   //第一次分類

   public void firstClassify()

   {

       //根據初始向量分類

       for(int i = 0;i < InstanceNumber;i++)

       {

      float min = 0f;

      int clsId = -1;

      for(int j = 0;j < classData.length;j++)

      {

          //歐式距離

          float newMin = calDis(classData[j],data[i]);

          if(clsId == -1 || newMin <min)

          {

         clsId = j;

         min = newMin;

          }

          

      }

      //本身不再新增

      if(!result.get(clsId).contains(i))

          result.get(clsId).add(i);

       }

   }

   //迭代分類,直到各個類的資料不再變化

   public void Adjust()

   {

       //記錄是否發生變化

       boolean change = true;

       

       //迴圈的次數

       int times = 1;

       while(change)

       {

      //復位

      change = false;

      System.out.println("Adjust Iteration"+(times++)+"time(s)");

                    

      //重新計算每個類的均值  

      for(int i = 0;i < ClassCount; i++){  

      //原有的資料  

      ArrayList<Integer> cls = result.get(i);  

        

      //新的均值  

      float[] newMean = new float[FieldCount ];  

       

      //計算均值  

      for(Integer index:cls){  

       for(int j = 0;j < FieldCount ;j++)  

              newMean[j] += data[index][j];  

       }  

      for(int j = 0;j < FieldCount ;j++)  

         newMean[j] /= cls.size();  

      if(!compareMean(newMean, classData[i])){  

         classData[i] = newMean;  

           change = true;  

           }  

      }  

      //清空之前的資料  

      for(ArrayList<Integer> cls:result)  

       cls.clear();  

         

      //重新分配  

      for(int i = 0;i < InstanceNumber;i++)  

      {  

       float min = 0f;  

       int clsId = -1;  

       for(int j = 0;j < classData.length;j++){  

        float newMin = calDis(classData[j], data[i]);  

       if(clsId == -1 || newMin < min){  

         clsId = j;  

           min = newMin;  

               }  

                 }  

                   data[i][FieldCount] = clsId;  

                    result.get(clsId).add(i);  

              }  

                  

         //測試聚類效果(訓練集)  

      //          for(int i = 0;i < ClassCount;i++){  

      //              int positives = 0;  

      //              int negatives = 0;  

      //              ArrayList<Integer> cls = result.get(i);  

      //              for(Integer instance:cls)  

      //                  if (data[instance][FieldCount - 1] == 1f)  

      //                      positives ++;  

      //                  else  

      //                      negatives ++;  

      //              System.out.println(" " + i + " Positive: " + positives + " Negatives: " + negatives);  

      //          }  

      //          System.out.println();  

       } 

                

                

   }  

           

         /** 

           * 計算a樣本和b樣本的歐式距離作為不相似度 

           *  

           * @param a     樣本a 

           * @param b     樣本b 

           * @return      歐式距離長度 

           */  

   private float calDis(float[] aVector,float[] bVector)  { 

      double dis = 0; 

      int i = 0;

               /*最後一個資料在訓練集中為結果,所以不考慮  */

                for(;i < aVector.length;i++)

                     dis += Math.pow(bVector[i] - aVector[i],2);  

                dis = Math.pow(dis, 0.5);  

                return (float)dis;  

   }

          

        /** 

         * 判斷兩個均值向量是否相等 

         *  

         * @param a 向量a 

              * @param b 向量b 

         * @return 

         */  

       private boolean compareMean(float[] a,float[] b)  

       {  

             if(a.length != b.length)  

               return false;  

             for(int i =0;i < a.length;i++){  

             if(a[i] > 0 &&b[i] > 0&& a[i] != b[i]){  

                  return false;  

                }     

            }  

              return true;  

        }  

           

        /** 

         * 將結果輸出到一個檔案中 

         *  

              * @param fileName 

              */  

         public void printResult(String fileName)  

       {  

       FileWriter fw = null;  

            BufferedWriter bw = null;  

            try {  

                  fw = new FileWriter(fileName);  

               bw = new BufferedWriter(fw);  

              //寫入檔案  

               for(int i = 0;i < InstanceNumber;i++)  

               {  

                  bw.write(String.valueOf(data[i][FieldCount]).substring(0, 1));  

                   bw.newLine();  

                }  

                

               //統計每類的數目,列印到控制檯  

               for(int i = 0;i < ClassCount;i++)  

             {  

                     System.out.println("第" + (i+1) + "類數目: " + result.get(i).size());  

              }  

         } catch (IOException e) {  

             e.printStackTrace();  

           } finally{  

                  

               //關閉資源  

             if(bw != null)  

                   try {  

                     bw.close();  

                   } catch (IOException e) {  

                       e.printStackTrace();  

                  }  

               if(fw != null)  

                    try {  

                        fw.close();  

                   } catch (IOException e) {  

                         e.printStackTrace();  

                    }  

             }  

              

        }  

      }

相關文章