Apriori演算法 java程式碼

冰凌其發表於2017-12-26

訓練集:

 

結果:

 

程式碼:

import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.TreeMap;
import java.util.TreeSet;

public class  Apriori {

    public static void main(String[] args) throws Exception {

        // 初始化事務集  
        List<Set<String>> trans = new LinkedList<Set<String>>();
        trans.add(new ItemSet(new String[] { "I1", "I2", "I5" }));
        trans.add(new ItemSet(new String[] { "I2", "I4" }));
        trans.add(new ItemSet(new String[] { "I2", "I3" }));
        trans.add(new ItemSet(new String[] { "I1", "I2", "I4" }));
        trans.add(new ItemSet(new String[] { "I1", "I3" }));
        trans.add(new ItemSet(new String[] { "I2", "I3" }));
        trans.add(new ItemSet(new String[] { "I1", "I3" }));
        trans.add(new ItemSet(new String[] { "I1", "I2", "I3", "I5" }));
        trans.add(new ItemSet(new String[] { "I1", "I2", "I3" }));

        int MSF = 2; // 設定最小支援頻次為2  

        Map<Integer, Set<ItemSet>> rst = findFrequentItemSets(trans, MSF);

        // 輸出頻繁項集  
        System.out.println("Frequent Item Sets:");
        for (Entry<Integer, Set<ItemSet>> entry : rst.entrySet()) {
            Integer itemSetSize = entry.getKey();
            System.out.printf("Frequent %d Item Sets:\n", itemSetSize);
            for (ItemSet set : entry.getValue())
                System.out.printf("%s, %d\n", set, set.frequence);
        }

        double MCONF = 0.6; // 設定最小置信度為60%  

        Map<ItemSet, ItemSet> directMap = new HashMap<ItemSet, ItemSet>();
        for (Entry<Integer, Set<ItemSet>> entry : rst.entrySet()) {
            for (ItemSet set : entry.getValue())
                directMap.put(set, set);
        }

        // 根據頻繁項集構造關聯規則  
        System.out.println();
        System.out.println("Association Rules:");
        for (Entry<Integer, Set<ItemSet>> entry : rst.entrySet()) {
            for (ItemSet set : entry.getValue()) {
                double cnt1 = directMap.get(set).frequence;
                List<ItemSet> subSets = set.listNotEmptySubItemSets();
                for (ItemSet subSet : subSets) {
                    int cnt2 = directMap.get(subSet).frequence;
                    double conf = cnt1 / cnt2;
                    if (cnt1 / cnt2 >= MCONF) {
                        ItemSet remainSet = new ItemSet();
                        remainSet.addAll(set);
                        remainSet.removeAll(subSet);
                        System.out.printf("%s => %s, %.2f\n", subSet,
                                remainSet, conf);
                    }
                }
            }
        }
    }

    /**
     * 查詢事務集中的所有頻繁項集,返回Map為:L -> 所有頻繁L項集的列表 
     */
    static Map<Integer, Set<ItemSet>> findFrequentItemSets(
            Iterable<Set<String>> transIterable, int MSF) {
        Map<Integer, Set<ItemSet>> ret = new TreeMap<Integer, Set<ItemSet>>();

        // 首先確定頻繁1項集  
        Iterator<Set<String>> it = transIterable.iterator();
        Set<ItemSet> oneItemSets = findFrequentOneItemSets(it, MSF);
        ret.put(1, oneItemSets);

        int preItemSetSize = 1;
        Set<ItemSet> preItemSets = oneItemSets;

        // 基於獲得的所有頻繁L-1項集迭代查詢所有頻繁L項集,直到不存在頻繁L-1項集  
        while (!preItemSets.isEmpty()) {
            int curItemSetSize = preItemSetSize + 1;

            // 獲取頻繁L項集的所有候選L項集  
            List<ItemSet> candidates = aprioriGenCandidates(preItemSets);

            // 掃描事務集以確定所有候選L項集出現的頻次  
            it = transIterable.iterator();
            while (it.hasNext()) {
                Set<String> tran = it.next();
                for (ItemSet candidate : candidates)
                    if (tran.containsAll(candidate))
                        candidate.frequence++;
            }

            // 將出現頻次不小於最小支援頻次的候選L項集選為頻繁L項集  
            Set<ItemSet> curItemSets = new HashSet<ItemSet>();
            for (ItemSet candidate : candidates)
                if (candidate.frequence >= MSF)
                    curItemSets.add(candidate);
            if (!curItemSets.isEmpty())
                ret.put(curItemSetSize, curItemSets);

            preItemSetSize = curItemSetSize;
            preItemSets = curItemSets;
        }
        return ret;
    }

    /**
     * 掃描事務集以確定頻繁1項集 
     */
    static Set<ItemSet> findFrequentOneItemSets(Iterator<Set<String>> trans,
                                                int MSF) {

        // 掃描事務集以確定各個項出現的頻次  
        Map<String, Integer> frequences = new HashMap<String, Integer>();
        while (trans.hasNext()) {
            Set<String> tran = trans.next();
            for (String item : tran) {
                Integer frequence = frequences.get(item);
                frequence = frequence == null ? 1 : frequence + 1;
                frequences.put(item, frequence);
            }
        }

        // 用每個出現頻次不小於最小支援頻次的項構造一個頻繁1項集  
        Set<ItemSet> ret = new HashSet<ItemSet>();
        for (Entry<String, Integer> entry : frequences.entrySet()) {
            String item = entry.getKey();
            Integer frequence = entry.getValue();
            if (frequence >= MSF) {
                ItemSet set = new ItemSet(new String[] { item });
                set.frequence = frequence;
                ret.add(set);
            }
        }
        return ret;
    }

    /**
     * 根據所有頻繁L-1項集獲得所有頻繁L項集的候選L項集 
     */
    static List<ItemSet> aprioriGenCandidates(Set<ItemSet> preItemSets) {
        List<ItemSet> ret = new LinkedList<ItemSet>();

        // 嘗試將所有頻繁L-1項集兩兩連線然後作剪枝處理以獲得候選L項集  
        for (ItemSet set1 : preItemSets) {
            for (ItemSet set2 : preItemSets) {
                if (set1 != set2 && set1.canMakeJoin(set2)) {

                    // 連線  
                    ItemSet union = new ItemSet();
                    union.addAll(set1);
                    union.add(set2.last());

                    // 剪枝  
                    boolean missSubSet = false;
                    List<ItemSet> subItemSets = union.listDirectSubItemSets();
                    for (ItemSet itemSet : subItemSets) {
                        if (!preItemSets.contains(itemSet)) {
                            missSubSet = true;
                            break;
                        }
                    }
                    if (!missSubSet)
                        ret.add(union);
                }
            }
        }
        return ret;
    }

    /**
     * 由多個項組成的項集,每個項是一個字串。使用TreeSet使項集中的項有序,以輔助演算法實現 
     */
    static class ItemSet extends TreeSet<String> {

        private static final long serialVersionUID = 23883315835136949L;

        int frequence; // 項集出現的頻次  

        public ItemSet() {
            this(new String[0]);
        }

        public ItemSet(String[] items) {
            for (String item : items)
                add(item);
        }

        /**
         * 測試本項集(假定階為L-1)能否與別一個項集連線以生成L階項集 
         */
        public boolean canMakeJoin(ItemSet other) {

            // 若兩個項集的階不同,則不能連線生成L階項集  
            if (other.size() != this.size())
                return false;

            // 假定項集的階為L-1,在項有序的前提下,當且僅當兩個項集的前L-2個項相同  
            // 而本項集的第L-1個項小於另一個項集的第L-1個項時,可以連線生成L階項集  
            Iterator<String> it1 = this.iterator();
            Iterator<String> it2 = other.iterator();
            while (it1.hasNext()) {
                String item1 = it1.next();
                String item2 = it2.next();
                int result = item1.compareTo(item2);
                if (result != 0) {
                    if (it1.hasNext())
                        return false;
                    return result < 0 ? true : false;
                }
            }
            return false;
        }

        /**
         * 假定本項集的階為L,列舉本項集的所有階為L-1的子項集 
         */
        public List<ItemSet> listDirectSubItemSets() {
            List<ItemSet> ret = new LinkedList<ItemSet>();

            // 只有本項集的階大於1,才可能存在非空子項集  
            if (size() > 1) {
                for (String rmItem : this) {
                    ItemSet subSet = new ItemSet();
                    subSet.addAll(this);
                    subSet.remove(rmItem);
                    ret.add(subSet);
                }
            }

            return ret;
        }

        /**
         * 列出本項集除自身外的所有非空子項集 
         */
        public List<ItemSet> listNotEmptySubItemSets() {
            List<ItemSet> ret = new LinkedList<ItemSet>();
            int size = size();
            if (size > 0) {
                char[] mapping = new char[size()];
                initMapping(mapping);
                while (nextMapping(mapping)) {
                    ItemSet set = new ItemSet();
                    Iterator<String> it = this.iterator();
                    for (int i = 0; i < size; i++) {
                        String item = it.next();
                        if (mapping[i] == '1')
                            set.add(item);
                    }
                    if (set.size() < size)
                        ret.add(set);
                }
            }
            return ret;
        }

        private void initMapping(char[] mapping) {
            for (int i = 0; i < mapping.length; i++)
                mapping[i] = '0';
        }

        private boolean nextMapping(char[] mapping) {
            int pos = 0;
            while (pos < mapping.length && mapping[pos] == '1') {
                mapping[pos] = '0';
                pos++;
            }
            if (pos < mapping.length) {
                mapping[pos] = '1';
                return true;
            }
            return false;
        }
    }
}  


相關文章