資料探勘聚類之k-medoids演算法實現

勿在浮沙築高臺LS發表於2017-01-20

程式碼如下:

package com.winning.dm.pathway;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.Map;

import com.winning.dm.test.util.StringUtils;

public class NewKmeans {

    // 儲存聚類之後的資料
    public ArrayList<ArrayList<String>> groupby;
    // 儲存中心
    public ArrayList<String> alcenter;
    // 迭代後的中心
    public ArrayList<String> alnewcenter;
    // 計算幾個類的輪廓係數
    public ArrayList<Double> adoutline;
    private static Logger logger = Logger.getLogger(NewKmeans.class);

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-6 下午4:35:55
     * @描述: 計算資料的漢明距離
     * @param x
     *            字串x
     * @param y
     *            字串y
     * @return
     * @備註:
     */
    public int distance(String x, String y) {
        int distance;
        if (x.length() != y.length()) {
            distance = -1;
        } else {
            distance = 0;
            for (int i = 0; i < x.length(); i++) {
                if (x.charAt(i) != y.charAt(i)) {
                    distance++;
                }
            }
        }
        return distance;
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-6 下午4:43:43
     * @描述: 聚類中心的迭代方法
     * @param list
     * @return
     * @備註:
     */
    public String centerIteration(ArrayList<String> list) {
        int size = list.size();
        int[] array = new int[size];
        for (int i = 0; i < size; i++) {
            for (int j = 0; j < size; j++) {
                array[i] = array[i] + distance(list.get(i), list.get(j));
            }
        }
        return list.get(findMinIndex(array));
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-6 下午5:04:49
     * @描述: 找到陣列中的最小值對應的下標
     * @param array
     *            陣列
     * @return 下標
     * @備註:
     */
    public int findMinIndex(int[] array) {
        int index = 0;
        int size = array.length;
        int min = array[0];
        for (int i = 0; i < size; i++) {
            if (min > array[i]) {
                min = array[i];
                index = i;
            }
        }
        return index;
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-11 上午11:13:00
     * @描述: 計算最小值的下標
     * @param array
     *            目標陣列
     * @return 下標
     * @備註:
     */
    public int findMinIndex(ArrayList<Double> array) {
        int index = 0;
        int size = array.size();
        double min = array.get(0);
        for (int i = 0; i < size; i++) {
            if (min > array.get(i)) {
                min = array.get(i);
                index = i;
            }
        }
        return index;
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-6 下午5:16:24
     * @描述: 把標號與ghdjid進行匹配
     * @return 標號與ghdjid的匹配
     * @throws IOException
     * @備註:
     */
    public ArrayList<String> findItem(String outputcsv, int xms)
            throws IOException {
        File file = new File(outputcsv);
        BufferedReader br = new BufferedReader(new FileReader(file));
        ArrayList<String> al = new ArrayList<String>();
        String itemindex = null;
        while ((itemindex = br.readLine()) != null) {
            String[] st = itemindex.split(",");
            int length = st.length;
            StringBuilder sb = new StringBuilder();
            for (int i = 1; i < xms + 1; i++) {
                for (int j = 0; j < length; j++) {
                    if (i == Integer.valueOf(st[j])) {
                        sb.append('1');
                        j = length - 1;
                        continue;
                    }
                    if (Integer.valueOf(st[j]) > i) {
                        sb.append('0');
                        j = length - 1;
                        continue;
                    }
                    if (Integer.valueOf(st[length - 1]) < i) {
                        sb.append('0');
                        j = length - 1;
                        continue;
                    }
                }
            }
//          System.out.println(sb.toString());
            al.add(sb.toString());
        }
        return al;
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-10 下午4:42:29
     * @描述: 建立中心,選取患者作為中心
     * @param k
     *            建立中心的數量
     * @return 建立的中心
     * @備註:
     */
    public ArrayList<Integer> createCenter(int k, int zrs) {
        ArrayList<Integer> result = new ArrayList<Integer>();
        for (int i = 0; i < k; i++) {
            int a = (int) (Math.random() * zrs);
            if (i == 0) {
                result.add(a);
            }
            for (int j = 0; j < i; j++) {
                if (a == result.get(j)) {
                    i--;
                } else {
                    result.add(a);
                }
            }
        }
        return result;
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-11 上午9:24:23
     * @描述: 對資料進行分組
     * @param gradedata
     *            標號原始資料
     * @param center
     *            中心
     * @return
     * @備註:
     */
    public ArrayList<ArrayList<String>> groupBy(ArrayList<String> gradedata,
            ArrayList<String> center) {
        ArrayList<ArrayList<String>> groupby = new ArrayList<ArrayList<String>>();
        int zrs = gradedata.size();
        int centernum = center.size();
        for (int i = 0; i < centernum; i++) {
            ArrayList<String> aset = new ArrayList<String>();
            groupby.add(aset);
        }
        for (int i = 0; i < zrs; i++) {
            int[] distance = new int[centernum];
            for (int j = 0; j < centernum; j++) {
                distance[j] = distance(gradedata.get(i), center.get(j));
            }
            groupby.get(findMinIndex(distance)).add(gradedata.get(i));
        }
        return groupby;
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-11 上午10:47:45
     * @描述: 計算輪廓係數之和
     * @param groupby
     * @return
     * @備註:
     */
    public double outline(ArrayList<ArrayList<String>> groupby) {
        // 計算每一個人的s(i)
        int size = groupby.size();

        double sums = 0;
        for (int i = 0; i < size; i++) {
            int num = groupby.get(i).size();
            for (int j = 0; j < num; j++) {
                double a = counta(groupby.get(i).get(j), groupby.get(i));
                double b = countb(i, groupby.get(i).get(j), groupby);
                double max = 0;
                max = (a > b) ? a : b;
                sums = sums + (b - a) / max;
            }
        }
        return sums;
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-11 上午10:54:50
     * @描述: 計算簇內距離
     * @param target
     *            目標值
     * @param group
     *            簇
     * @return
     * @備註:
     */
    public double counta(String target, ArrayList<String> group) {
        int size = group.size();
        double result = 0;
        for (int i = 0; i < size; i++) {
            result = result + distance(target, group.get(i));
        }
        result = result / size;
        return result;
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-11 上午10:55:21
     * @描述: 計算簇間距離
     * @param target
     * @param group
     * @return
     * @備註:
     */
    public double countb(int grade, String target,
            ArrayList<ArrayList<String>> groupby) {
        int size = groupby.size();
        ArrayList<Double> result = new ArrayList<Double>();
        for (int i = 0; i < size; i++) {
            if (i == grade) {
                continue;
            } else {
                int num = groupby.get(i).size();
                double sum = 0;
                for (int j = 0; j < num; j++) {
                    sum = sum + distance(target, groupby.get(i).get(j));
                }
                sum = sum / num;
                result.add(sum);
            }
        }
        return result.get(findMinIndex(result));
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-11 下午1:26:59
     * @描述: 判斷最大值的索引
     * @param array
     *            目標集合
     * @return
     * @備註:
     */
    public int findMaxIndex(ArrayList<Double> array) {
        int index = 0;
        int size = array.size();
        double max = array.get(0);
        for (int i = 0; i < size; i++) {
            if (max < array.get(i)) {
                max = array.get(i);
                index = i;
            }
        }
        return index;
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-11 下午1:58:05
     * @描述: 根據輪廓係數找到簇的個數
     * @param gradedata
     *            原始資料
     * @return 簇的個數
     * @備註:
     */
    public int findGroupNum(ArrayList<String> gradedata) {
        adoutline = new ArrayList<Double>();
        for (int i = 2; i < 4; i++) {
            // 隨機進行迭代,多次迭代求平均的輪廓係數
            double sumoutline = 0;
            for (int t = 0; t < 10; t++) {
                // 隨機生成序號
                ArrayList<Integer> center = createCenter(i, gradedata.size());

                alcenter = new ArrayList<String>();
                for (int j = 0; j < i; j++) {
                    // 把序號轉成標號資料
                    alcenter.add(gradedata.get(center.get(j)));
                }
                groupby = groupBy(gradedata, alcenter);
                if (StringUtils.isEmpty(groupby)) {
                    t--;
                    continue;
                }
                // 檢視第一次是不是有空的集合
                double avgdistance = 20;
                while (avgdistance > 5) {
                    // 按照中心對資料進行分組,多次迭代求平均輪廓係數
                    groupby = groupBy(gradedata, alcenter);
                    alnewcenter = new ArrayList<String>();
                    for (int j = 0; j < i; j++) {
                        // 迭代求出新的中心
                        alnewcenter.add(centerIteration(groupby.get(j)));// 中心迭代的時候出現空的陣列
                    }
                    // 計算新的中心與舊的中心之間的距離
                    int sumdistance = 0;
                    for (int j = 0; j < i; j++) {
                        // 計算兩個中心之間的距離
                        sumdistance = sumdistance
                                + distance(alcenter.get(j), alnewcenter.get(j));
                    }
                    avgdistance = (double) sumdistance / (double) i;
                    alcenter = alnewcenter;
                }
                // 根據分組計算輪廓係數
                sumoutline = sumoutline + outline(groupby);
            }
            adoutline.add(sumoutline);
        }
        // k為最終的分類數量
        int k = 2 + findMaxIndex(adoutline);
        return k;
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-11 下午2:35:09
     * @描述: 找到分組的中心,分組的中心就是路徑
     * @param k
     *            簇的個數
     * @param gradedata
     *            原始資料
     * @return 返回最佳中心
     * @備註:
     */
    public ArrayList<String> findCenter(int k, ArrayList<String> gradedata) {
        // 儲存中心和對應的患者
        Map<String, ArrayList<String>> result = new HashMap<String, ArrayList<String>>();
        // 多次隨機生成中心防止得到區域性最優
        ArrayList<ArrayList<String>> aals = new ArrayList<ArrayList<String>>();
        // 儲存輪廓係數
        ArrayList<Double> outline = new ArrayList<Double>();
        for (int t = 0; t < 10; t++) {
            // 隨機生成序號
            ArrayList<String> alcenter = new ArrayList<String>();
            ArrayList<Integer> center = createCenter(k, gradedata.size());
            for (int j = 0; j < k; j++) {
                // 把序號轉成標號資料
                alcenter.add(gradedata.get(center.get(j)));
            }
            aals.add(alcenter);
            double avgdistance = 20;
            while (avgdistance > 5) {
                // 按照中心對資料進行分組,多次迭代求平均輪廓係數
                groupby = groupBy(gradedata, alcenter);
                for (int j = 0; j < k; j++) {
                    // 迭代求出新的中心
                    alnewcenter.add(centerIteration(groupby.get(j)));
                }
                // 計算新的中心與舊的中心之間的距離
                int sumdistance = 0;
                for (int j = 0; j < k; j++) {
                    // 計算兩個中心之間的距離
                    sumdistance = sumdistance
                            + distance(alcenter.get(j), alnewcenter.get(j));
                }
                avgdistance = (double) sumdistance / (double) k;
                alcenter = alnewcenter;
            }
            // 根據分組計算輪廓係數
            outline.add(outline(groupby));
        }
        return aals.get(findMaxIndex(outline));
    }

    /**
     * 
     * @作者: liusen
     * @時間: 2017-1-11 下午4:10:54
     * @描述: 根據聚類中心進行聚類
     * @return 聚類中心和對應的集合
     * @備註:
     */
    public Map<String, ArrayList<String>> findCenterMap(
            ArrayList<String> center, ArrayList<String> gradedata) {
        Map<String, ArrayList<String>> map = new HashMap<String, ArrayList<String>>();
        ArrayList<String> alnewcenter=new ArrayList<String>();
        double avgdistance = 20;
        int size = center.size();
        while (avgdistance > 5) {
            // 按照中心對資料進行分組,多次迭代求平均輪廓係數
            groupby = groupBy(gradedata, center);
            for (int j = 0; j < size; j++) {
                // 迭代求出新的中心
                alnewcenter.add(centerIteration(groupby.get(j)));
            }
            // 計算新的中心與舊的中心之間的距離
            int sumdistance = 0;
            for (int j = 0; j < size; j++) {
                // 計算兩個中心之間的距離
                sumdistance = sumdistance
                        + distance(center.get(j), alnewcenter.get(j));
            }
            avgdistance = (double) sumdistance / (double) size;
            center = alnewcenter;
        }
        for (int i = 0; i < size; i++) {
            map.put(center.get(i), groupby.get(i));
        }
        return map;
    }
}

相關文章