My Implementation of the Apriori Algorithm

2013/09/14

项目需要分布式的代码,于是先看了下算法思想,自己学着单机实现了一下。复杂度比较高,其中有两个递归实现。由于Java提供的Set都没有实现Comparable接口,所以自定义了一个Set叫做SimpleSet,其中实现了hashCode, equals, compareTo等函数。

以下代码完全没有指导意义,只作为个人工作记录。

代码如下:

SimpleSet.java

import java.io.*;
import java.util.*;

/**
 * A thin {@link Set} wrapper around a {@link HashSet} that additionally
 * implements {@link Comparable}, so that sets can be used as map keys and
 * ordered by hash code.
 *
 * <p>{@link #equals(Object)} and {@link #hashCode()} follow the general
 * {@link Set} contract: a {@code SimpleSet} is equal to any other {@code Set}
 * holding the same elements, and its hash code is the sum of the hash codes
 * of its elements.
 *
 * <p>The ordering defined by {@link #compareTo(Object)} is based on hash codes
 * only. Two different sets whose hash codes collide compare as {@code 0}, so
 * the ordering is <em>not</em> consistent with {@code equals}.
 *
 * @param <T> the element type
 */
public class SimpleSet<T> implements Set<T>, Comparable<Object> {
    /** Not used by this class; kept so existing references to the field still compile. */
    static int i = 0;

    /** Backing storage; every {@code Set} operation is delegated to it. */
    Set<T> set = new HashSet<T>();

    /** Creates an empty set. */
    public SimpleSet() {
    }

    /**
     * Creates a set holding the same elements as {@code ss}. Later changes to
     * either set do not affect the other.
     *
     * @param ss the set to copy; must not be {@code null}
     */
    public SimpleSet(SimpleSet<? extends T> ss) {
        this.set.addAll(ss.set);
    }

    /**
     * Returns the sum of the hash codes of the elements (a {@code null}
     * element counts as 0), as required by the {@link Set} contract.
     */
    @Override
    public int hashCode() {
        return this.set.hashCode();
    }

    /**
     * Returns {@code true} when {@code obj} is a {@link Set} of the same size
     * whose elements are all contained in this set.
     */
    @Override
    public boolean equals(Object obj) {
        // The backing HashSet already implements the symmetric Set comparison
        // (type check, size check, containsAll).
        return obj == this || this.set.equals(obj);
    }

    /**
     * Returns {@code ""} for an empty set, otherwise the elements separated by
     * commas and wrapped in square brackets, e.g. {@code [a,b]}. The element
     * order is the iteration order of the backing set.
     */
    @Override
    public String toString() {
        if (this.set.isEmpty()) {
            return "";
        }
        StringBuilder str = new StringBuilder("[");
        Iterator<T> it = this.set.iterator();
        while (it.hasNext()) {
            str.append(it.next()); // append(Object) renders a null element as "null"
            if (it.hasNext()) {
                str.append(',');
            }
        }
        return str.append(']').toString();
    }

    @Override
    public boolean add(T e) {
        return this.set.add(e);
    }

    @Override
    public boolean addAll(Collection<? extends T> c) {
        return this.set.addAll(c);
    }

    @Override
    public void clear() {
        this.set.clear();
    }

    @Override
    public boolean contains(Object o) {
        return this.set.contains(o);
    }

    @Override
    public boolean containsAll(Collection<?> c) {
        return this.set.containsAll(c);
    }

    @Override
    public boolean isEmpty() {
        return this.set.isEmpty();
    }

    @Override
    public Iterator<T> iterator() {
        return this.set.iterator();
    }

    @Override
    public boolean remove(Object o) {
        return this.set.remove(o);
    }

    @Override
    public boolean removeAll(Collection<?> c) {
        return this.set.removeAll(c);
    }

    @Override
    public boolean retainAll(Collection<?> c) {
        return this.set.retainAll(c);
    }

    @Override
    public int size() {
        return this.set.size();
    }

    @Override
    public Object[] toArray() {
        return this.set.toArray();
    }

    /**
     * Convenience overload for string arrays; behaves like
     * {@link Set#toArray(Object[])}.
     */
    public String[] toArray(String[] a) {
        return this.set.toArray(a);
    }

    /**
     * Copies the elements into {@code a}, or into a new array of the same
     * runtime type when {@code a} is too small.
     */
    @Override
    public <E> E[] toArray(E[] a) {
        return this.set.toArray(a);
    }

    /**
     * Orders this set relative to {@code o} by comparing hash codes.
     *
     * @param o the object to compare with; must not be {@code null}
     * @return a negative value, zero, or a positive value when this set's hash
     *         code is less than, equal to, or greater than that of {@code o}
     */
    @Override
    public int compareTo(Object o) {
        int mine = this.hashCode();
        int theirs = o.hashCode();
        // Compare explicitly: subtracting the two values can overflow and
        // flip the sign of the result.
        if (mine < theirs) {
            return -1;
        }
        return (mine == theirs) ? 0 : 1;
    }
}
Apriori.java

import java.util.*;
import java.io.*;

/**
 * Single-machine implementation of the Apriori algorithm.
 *
 * <p>Transactions are read from {@code Test.txt} (one transaction per line,
 * items separated by a single space). Frequent itemsets are appended to
 * {@code FrequentPatterns.txt} and association rules to
 * {@code AssociationRules.txt}.
 */
public class Apriori {

    /** Support count of every single item, as produced by {@link #buildData}. */
    static Map<String, Integer> data = new HashMap<String, Integer>();
    /** All transactions read so far, one item set per input line. */
    static List<SimpleSet<String>> lineSet = new ArrayList<SimpleSet<String>>();
    /** Every frequent itemset found so far, mapped to its support count. */
    static Map<SimpleSet<String>, Integer> dataSet = new HashMap<SimpleSet<String>, Integer>();

    /** Minimum number of transactions an itemset must occur in to be frequent. */
    static final int MIN_SUP = 1;
    /** Minimum confidence a rule must reach to be written out. */
    static final float MIN_CONF = 0.1f;
    /** Writer for the frequent-pattern report. */
    static BufferedWriter bw1 = null;
    /** Writer for the association-rule report. */
    static BufferedWriter bw2 = null;

    /**
     * Runs the whole pipeline on {@code Test.txt}.
     *
     * @param args ignored
     * @throws IOException if the input cannot be read or a report cannot be written
     */
    public static void main(String[] args) throws IOException {
        bw1 = new BufferedWriter(new FileWriter("FrequentPatterns.txt", true));
        try {
            bw2 = new BufferedWriter(new FileWriter("AssociationRules.txt", true));
            try {
                mine("Test.txt");
            } finally {
                bw2.close();
            }
        } finally {
            // Closed even when mining fails, so buffered output is not lost.
            bw1.close();
        }
    }

    /** Mines frequent itemsets and rules from {@code srcFile} and reports timings on stdout. */
    private static void mine(String srcFile) throws IOException {
        long startTime = System.currentTimeMillis();

        // A read failure propagates: continuing with empty data would only
        // append meaningless empty reports.
        data = buildData(srcFile);

        Map<SimpleSet<String>, Integer> result = getF1Set(data);
        int i = 1;
        do {
            printCandidate(result, i);
            i++;
            result = recurseGen(result);
        } while (result.size() > 0);
        System.out.println("---------------------------Frequent Patterns---------------------------\n");
        long endTime1 = System.currentTimeMillis();
        System.out.println((endTime1 - startTime) + " ms");

        // Seed rule generation with every frequent itemset, not just the last
        // level: a maximal frequent itemset can be smaller than the largest one.
        List<SimpleSet<String>> frequentSets = new ArrayList<SimpleSet<String>>(
                dataSet.keySet());
        System.out.println("---------------------------Association Rule---------------------------\n");
        calcAndPrintAsso(frequentSets);
        long endTime2 = System.currentTimeMillis();
        System.out.println((endTime2 - endTime1) + " ms");

        System.out.println("----------------------------Consumed Time----------------------------\n");
        long endTime3 = System.currentTimeMillis();
        System.out.println((endTime3 - startTime) + " ms");
    }

    /**
     * Writes, for every given itemset and recursively for its subsets, the
     * rules {@code subset --> itemset} whose confidence is at least
     * {@link #MIN_CONF}. Each itemset is processed once per call.
     *
     * @param maxFI the itemsets to start from
     * @throws IOException if the rule report cannot be written
     */
    private static void calcAndPrintAsso(List<SimpleSet<String>> maxFI) throws IOException {
        calcAndPrintAsso(maxFI, new HashSet<SimpleSet<String>>());
    }

    /** Recursive worker; {@code visited} prevents writing the rules of an itemset twice. */
    private static void calcAndPrintAsso(List<SimpleSet<String>> itemSets,
            Set<SimpleSet<String>> visited) throws IOException {
        for (SimpleSet<String> itemSet : itemSets) {
            if (itemSet.size() <= 1 || !visited.add(itemSet)) {
                continue;
            }
            Integer support = dataSet.get(itemSet);
            if (support == null) {
                continue; // not frequent: no rules
            }
            List<SimpleSet<String>> subs = getSubsets(itemSet);
            for (SimpleSet<String> sub : subs) {
                Integer subSupport = dataSet.get(sub);
                if (subSupport == null) {
                    continue;
                }
                float confidence = (float) support / subSupport;
                if (confidence >= MIN_CONF) {
                    bw2.write(sub.toString() + " --> " + itemSet.toString()
                            + " conf=" + confidence);
                    bw2.newLine();
                }
            }
            calcAndPrintAsso(subs, visited);
        }
    }

    /**
     * Appends one level of frequent itemsets to the frequent-pattern report.
     *
     * @param result the frequent itemsets of this level with their support counts
     * @param i      the number of items per itemset on this level
     * @throws IOException if the report cannot be written
     */
    private static void printCandidate(Map<SimpleSet<String>, Integer> result,
            int i) throws IOException {
        int size = result.size();
        bw1.write(i + "-item set (num: " + size + "):\n");
        bw1.write(result.toString());
        bw1.newLine();
    }

    /**
     * Builds the frequent (k+1)-itemsets from the frequent k-itemsets.
     * Every frequent itemset found is also recorded in {@link #dataSet}.
     *
     * @param preMap the frequent k-itemsets with their support counts
     * @return the frequent (k+1)-itemsets with their support counts
     */
    private static Map<SimpleSet<String>, Integer> recurseGen(
            Map<SimpleSet<String>, Integer> preMap) {
        List<SimpleSet<String>> keyList = new ArrayList<SimpleSet<String>>(preMap.keySet());

        // The join step compares item prefixes, which is only meaningful on
        // sorted items; the iteration order of a hash set is not sorted.
        List<String[]> sortedItems = new ArrayList<String[]>(keyList.size());
        for (SimpleSet<String> key : keyList) {
            String[] items = key.toArray(new String[0]);
            Arrays.sort(items);
            sortedItems.add(items);
        }

        Set<SimpleSet<String>> preKeys = preMap.keySet();
        Map<SimpleSet<String>, Integer> result = new HashMap<SimpleSet<String>, Integer>();
        int preSize = keyList.size();
        for (int i = 0; i < preSize - 1; i++) {
            String[] pre1 = sortedItems.get(i);
            for (int j = i + 1; j < preSize; j++) {
                String[] pre2 = sortedItems.get(j);
                if (!maybeLinkable(pre1, pre2)) {
                    continue;
                }
                SimpleSet<String> superSet = new SimpleSet<String>(keyList.get(i));
                superSet.add(pre2[pre2.length - 1]);

                if (shouldCut(preKeys, superSet)) {
                    continue;
                }
                int count = checkSup(superSet);
                if (count >= MIN_SUP) {
                    result.put(superSet, count);
                    dataSet.put(superSet, count); // keep the global support table complete
                }
            }
        }
        return result;
    }

    /** Counts the transactions in {@link #lineSet} that contain every item of {@code superSet}. */
    private static int checkSup(SimpleSet<String> superSet) {
        int count = 0;
        for (SimpleSet<String> line : lineSet) {
            if (line.containsAll(superSet)) {
                count++;
            }
        }
        return count;
    }

    /**
     * Prune step: a candidate cannot be frequent when one of its subsets with
     * one item less is not among the frequent itemsets of the previous level.
     *
     * @param preKeys  the frequent itemsets of the previous level
     * @param superSet the candidate itemset
     * @return {@code true} when the candidate must be discarded
     */
    private static boolean shouldCut(Collection<SimpleSet<String>> preKeys,
            SimpleSet<String> superSet) {
        for (SimpleSet<String> s : getSubsets(superSet)) {
            if (!preKeys.contains(s)) {
                return true;
            }
        }
        return false;
    }

    /** Returns every subset of {@code superSet} obtained by removing exactly one item. */
    private static List<SimpleSet<String>> getSubsets(SimpleSet<String> superSet) {
        List<SimpleSet<String>> subsets = new ArrayList<SimpleSet<String>>();
        String[] superList = superSet.toArray(new String[0]);
        for (int i = 0; i < superList.length; i++) {
            SimpleSet<String> copyOfSuperSet = new SimpleSet<String>(superSet);
            copyOfSuperSet.remove(superList[i]);
            subsets.add(copyOfSuperSet);
        }
        return subsets;
    }

    /**
     * Join test on two sorted item arrays: they can be joined when they have
     * the same non-zero length, agree on all items but the last, and differ in
     * the last item.
     */
    private static boolean maybeLinkable(String[] pre1, String[] pre2) {
        int size = pre1.length;
        if (size == 0 || size != pre2.length) {
            return false;
        }
        for (int i = 0; i < size - 1; i++) {
            if (!pre1[i].equals(pre2[i])) {
                return false;
            }
        }
        return !pre1[size - 1].equals(pre2[size - 1]);
    }

    /**
     * Builds the frequent 1-itemsets and records them in {@link #dataSet}.
     *
     * @param srcdata support count per single item
     * @return the single-item sets whose count is at least {@link #MIN_SUP}
     */
    private static HashMap<SimpleSet<String>, Integer> getF1Set(
            Map<String, Integer> srcdata) {
        HashMap<SimpleSet<String>, Integer> f1Set = new HashMap<SimpleSet<String>, Integer>();
        for (Map.Entry<String, Integer> entry : srcdata.entrySet()) {
            int count = entry.getValue();
            if (count >= MIN_SUP) {
                SimpleSet<String> set = new SimpleSet<String>();
                set.add(entry.getKey());
                f1Set.put(set, count);
                dataSet.put(set, count);
            }
        }
        return f1Set;
    }

    /**
     * Reads the transactions from the first given file, appends them to
     * {@link #lineSet} and counts in how many transactions each item occurs.
     * Only {@code files[0]} is read.
     *
     * @param files the input files; at least one is required
     * @return the support count of every item
     * @throws IllegalArgumentException if no file is given
     * @throws IOException              if the file cannot be read
     */
    static Map<String, Integer> buildData(String... files) throws IOException {
        if (files == null || files.length == 0) {
            throw new IllegalArgumentException("at least one input file is required");
        }
        Map<String, Integer> tdata = new HashMap<String, Integer>();
        BufferedReader br = new BufferedReader(new FileReader(files[0]));
        try {
            String line;
            while ((line = br.readLine()) != null) {
                SimpleSet<String> lineItems = new SimpleSet<String>();
                for (String str : line.split(" ")) {
                    if (str.length() == 0) {
                        continue; // blank line or repeated separator, not an item
                    }
                    // Count an item once per transaction so the result agrees
                    // with checkSup, which counts transactions.
                    if (lineItems.add(str)) {
                        Integer seen = tdata.get(str);
                        tdata.put(str, seen == null ? 1 : seen + 1);
                    }
                }
                if (!lineItems.isEmpty()) {
                    lineSet.add(lineItems);
                }
            }
        } finally {
            br.close();
        }
        return tdata;
    }
}