项目需要分布式的 Apriori 代码，于是先学习了该算法的思想，自己尝试实现了一个单机版本。复杂度比较高，其中有两处递归实现。由于 Java 提供的 Set 都没有实现 Comparable 接口，所以自定义了一个名为 SimpleSet 的 Set，其中实现了 hashCode、equals、compareTo 等方法。
以下代码完全没有指导意义，仅作为个人工作记录。
代码如下：
SimpleSet.java
import java.io.*;
import java.util.*;
public class SimpleSet<T> implements Set, Comparable{
static int i =0;
Set<T> set = new HashSet<T>();
public SimpleSet() {
}
public SimpleSet(SimpleSet ss) {
this.set.addAll(ss.set);
}
public int hashCode() {
int type = this.getClass().hashCode();
int code = 0;
for(T t : this.set) {
code += (type*31 + t.hashCode());
}
return code;
}
public boolean equals(Object obj) {
boolean flag = false;
if(obj instanceof SimpleSet) {
SimpleSet<?> s = (SimpleSet<?>)obj;
if( (s.size()==this.size()) && s.containsAll(this.set)&& this.set.containsAll(s)) {
flag = true;
}
}
return flag;
}
public String toString() {
if (this.set.size() == 0) {
return "";
} else {
StringBuilder str = new StringBuilder();
str.append("[");
for (T t : this.set) {
str.append(t.toString() + ",");
}
String result = str.substring(0, str.length() - 1);
result += "]";
return result;
}
}
@Override
public boolean add(Object e) {
return this.set.add((T) e);
}
@Override
public boolean addAll(Collection c) {
return this.set.addAll(c);
}
@Override
public void clear() {
this.set.clear();
}
@Override
public boolean contains(Object o) {
return this.set.contains(o);
}
@Override
public boolean containsAll(Collection c) {
return this.set.containsAll(c);
}
@Override
public boolean isEmpty() {
return this.set.isEmpty();
}
@Override
public Iterator iterator() {
return this.set.iterator();
}
@Override
public boolean remove(Object o) {
return this.set.remove(o);
}
@Override
public boolean removeAll(Collection c) {
return this.set.removeAll(c);
}
@Override
public boolean retainAll(Collection c) {
return this.set.retainAll(c);
}
@Override
public int size() {
return this.set.size();
}
@Override
public Object[] toArray() {
return this.set.toArray();
}
public String[] toArray(String[] a) {
return this.set.toArray(a);
}
@Override
public Object[] toArray(Object[] a) {
// TODO DO NOT USE
return null;
}
@Override
public int compareTo(Object o) {
return this.hashCode()-o.hashCode();
}
}
Apriori.java
import java.util.*;
import java.io.*;
/**
 * Single-machine Apriori: mines frequent itemsets from a whitespace-separated transaction
 * file and derives association rules from them.
 */
public class Apriori {
    /** Item -> number of transactions containing it; populated from {@link #buildData}. */
    static Map<String, Integer> data = new HashMap<String, Integer>();
    /** One entry per non-blank input line, holding that transaction's distinct items. */
    static List<SimpleSet<String>> lineSet = new ArrayList<SimpleSet<String>>();
    /** Every frequent itemset found so far (all sizes) -> its support count. */
    static Map<SimpleSet<String>, Integer> dataSet = new HashMap<SimpleSet<String>, Integer>();
    /** Minimum support, as an absolute number of transactions. */
    static final int MIN_SUP = 1;
    /** Minimum confidence a rule needs in order to be written out. */
    static final float MIN_CONF = 0.1f;
    /** Writer for FrequentPatterns.txt; opened and closed by {@link #main}. */
    static BufferedWriter bw1 = null;
    /** Writer for AssociationRules.txt; opened and closed by {@link #main}. */
    static BufferedWriter bw2 = null;

    /**
     * Runs Apriori over {@code Test.txt}. Each level of frequent itemsets is appended to
     * FrequentPatterns.txt and the association rules are appended to AssociationRules.txt.
     * Timings are printed to stdout. Both output files are closed even if mining fails.
     *
     * @param args ignored
     * @throws IOException if the input cannot be read or an output file cannot be written
     */
    public static void main(String[] args) throws IOException {
        bw1 = new BufferedWriter(new FileWriter("FrequentPatterns.txt", true));
        try {
            bw2 = new BufferedWriter(new FileWriter("AssociationRules.txt", true));
            try {
                runPipeline();
            } finally {
                bw2.close();
            }
        } finally {
            bw1.close();
        }
    }

    /** Mines frequent itemsets level by level, then derives and writes association rules. */
    private static void runPipeline() throws IOException {
        long startTime = System.currentTimeMillis();
        String srcFile = "Test.txt";
        // Let a read failure propagate: continuing with empty data would only produce
        // misleading, empty output files.
        data = buildData(srcFile);
        Map<SimpleSet<String>, Integer> result = getF1Set(data);
        int i = 1;
        do {
            printCandidate(result, i);
            i++;
            result = recurseGen(result);
        } while (result.size() > 0);
        System.out.println("---------------------------Frequent Patterns---------------------------\n");
        long endTime1 = System.currentTimeMillis();
        System.out.println((endTime1 - startTime) + " ms");
        // Rules must be derived from every frequent itemset, not only those of the last
        // (largest) level: a frequent itemset need not be contained in any larger one.
        List<SimpleSet<String>> allFrequent = new ArrayList<SimpleSet<String>>(dataSet.keySet());
        System.out.println("---------------------------Association Rule---------------------------\n");
        calcAndPrintAsso(allFrequent);
        long endTime2 = System.currentTimeMillis();
        System.out.println((endTime2 - endTime1) + " ms");
        System.out.println("----------------------------Consumed Time----------------------------\n");
        long endTime3 = System.currentTimeMillis();
        System.out.println((endTime3 - startTime) + " ms");
    }

    /**
     * Writes rules {@code sub --> itemset} to bw2 for each itemset of size > 1 and each
     * subset obtained by removing one item. Confidence is support(itemset) / support(sub),
     * and only rules with confidence >= MIN_CONF are written. It then recurses into those
     * subsets. Each itemset is processed at most once, so no rule is written twice.
     *
     * @param maxFI itemsets to start from
     * @throws IOException if writing to bw2 fails
     */
    private static void calcAndPrintAsso(List<SimpleSet<String>> maxFI) throws IOException {
        calcAndPrintAsso(maxFI, new HashSet<SimpleSet<String>>());
    }

    /** Recursive worker for {@link #calcAndPrintAsso(List)}; {@code visited} prevents duplicates. */
    private static void calcAndPrintAsso(List<SimpleSet<String>> itemsets,
            Set<SimpleSet<String>> visited) throws IOException {
        for (SimpleSet<String> oneBigSet : itemsets) {
            if (oneBigSet.size() <= 1 || !visited.add(oneBigSet)) {
                continue;
            }
            Integer numerator = dataSet.get(oneBigSet);
            if (numerator == null) {
                continue;
            }
            List<SimpleSet<String>> subs = getSubsets(oneBigSet);
            for (SimpleSet<String> sub : subs) {
                Integer denominator = dataSet.get(sub);
                if (denominator == null) {
                    continue;
                }
                float confidence = (float) numerator / denominator;
                if (confidence >= MIN_CONF) {
                    bw2.write(sub.toString() + " --> " + oneBigSet.toString()
                            + " conf=" + confidence);
                    bw2.newLine();
                }
            }
            calcAndPrintAsso(subs, visited);
        }
    }

    /** Appends one level of frequent itemsets (with supports) to bw1. */
    private static void printCandidate(Map<SimpleSet<String>, Integer> result,
            int i) throws IOException {
        int size = result.size();
        bw1.write(i + "-item set (num: " + size + "):\n");
        bw1.write(result.toString());
        bw1.newLine();
    }

    /**
     * Builds the frequent (k+1)-itemsets from the frequent k-itemsets in {@code preMap}.
     * The join step pairs itemsets whose sorted items agree on all but the last position.
     * Candidates with an infrequent k-subset are pruned, and the survivors are support-counted.
     * Each frequent candidate is also recorded in {@link #dataSet}.
     *
     * @param preMap frequent k-itemsets -> support
     * @return frequent (k+1)-itemsets -> support (empty when none remain)
     */
    private static Map<SimpleSet<String>, Integer> recurseGen(
            Map<SimpleSet<String>, Integer> preMap) {
        List<SimpleSet<String>> keyList = new ArrayList<SimpleSet<String>>(preMap.keySet());
        int preSize = keyList.size();
        // Sort every itemset's items once. HashSet iteration order is not a consistent total
        // order across different sets, and the prefix join below requires one.
        List<String[]> sortedKeys = new ArrayList<String[]>(preSize);
        for (SimpleSet<String> key : keyList) {
            String[] items = key.toArray(new String[0]);
            Arrays.sort(items);
            sortedKeys.add(items);
        }
        Map<SimpleSet<String>, Integer> result = new HashMap<SimpleSet<String>, Integer>();
        for (int i = 0; i < preSize - 1; i++) {
            String[] pre1 = sortedKeys.get(i);
            for (int j = i + 1; j < preSize; j++) {
                String[] pre2 = sortedKeys.get(j);
                if (!maybeLinkable(pre1, pre2)) {
                    continue;
                }
                SimpleSet<String> superSet = new SimpleSet<String>();
                for (String str : pre1) {
                    superSet.add(str);
                }
                superSet.add(pre2[pre2.length - 1]);
                if (!shouldCut(preMap.keySet(), superSet)) {
                    int count = checkSup(superSet);
                    if (count >= MIN_SUP) {
                        result.put(superSet, count);
                        dataSet.put(superSet, count); // global record of all frequent itemsets
                    }
                }
            }
        }
        return result;
    }

    /** Returns the number of transactions in {@link #lineSet} that contain every item of {@code superSet}. */
    private static int checkSup(SimpleSet<String> superSet) {
        int count = 0;
        for (SimpleSet<String> line : lineSet) {
            if (line.containsAll(superSet)) {
                count++;
            }
        }
        return count;
    }

    /**
     * Apriori pruning: returns true when some subset of {@code superSet} (one item removed)
     * is not among the previous level's frequent itemsets.
     *
     * @param preKeyList previous level's frequent itemsets; a hash-based set gives O(1) lookups
     * @param superSet candidate itemset
     */
    private static boolean shouldCut(Collection<SimpleSet<String>> preKeyList,
            SimpleSet<String> superSet) {
        for (SimpleSet<String> s : getSubsets(superSet)) {
            if (!preKeyList.contains(s)) {
                return true;
            }
        }
        return false;
    }

    /** Returns every subset of {@code superSet} that has exactly one item removed. */
    private static List<SimpleSet<String>> getSubsets(SimpleSet<String> superSet) {
        String[] items = superSet.toArray(new String[0]);
        List<SimpleSet<String>> subsets = new ArrayList<SimpleSet<String>>(items.length);
        for (String item : items) {
            SimpleSet<String> subset = new SimpleSet<String>(superSet);
            subset.remove(item);
            subsets.add(subset);
        }
        return subsets;
    }

    /**
     * Join condition: the arrays are non-empty, have the same length, agree on every
     * position except the last, and differ in the last. Both arrays must be sorted
     * consistently for this to match the Apriori join.
     */
    private static boolean maybeLinkable(String[] pre1, String[] pre2) {
        int size = pre1.length;
        if (size == 0 || size != pre2.length) {
            return false;
        }
        for (int i = 0; i < size - 1; i++) {
            if (!pre1[i].equals(pre2[i])) {
                return false;
            }
        }
        return !pre1[size - 1].equals(pre2[size - 1]);
    }

    /**
     * Builds the frequent 1-itemsets from per-item counts. Each one is also recorded in
     * {@link #dataSet}.
     *
     * @param srcdata item -> support count
     * @return frequent 1-itemsets -> support
     */
    private static HashMap<SimpleSet<String>, Integer> getF1Set(
            Map<String, Integer> srcdata) {
        HashMap<SimpleSet<String>, Integer> f1Set = new HashMap<SimpleSet<String>, Integer>();
        for (Map.Entry<String, Integer> entry : srcdata.entrySet()) {
            int count = entry.getValue();
            if (count >= MIN_SUP) {
                SimpleSet<String> set = new SimpleSet<String>();
                set.add(entry.getKey());
                f1Set.put(set, count);
                dataSet.put(set, count);
            }
        }
        return f1Set;
    }

    /**
     * Reads transactions, one per line with items separated by whitespace, from every
     * given file. Each transaction's distinct items are appended to {@link #lineSet}.
     * Blank lines are skipped.
     *
     * @param files input files; an empty array yields an empty result
     * @return item -> number of transactions containing that item (duplicates within a
     *         line are counted once)
     * @throws IOException if a file cannot be read
     */
    static Map<String, Integer> buildData(String... files) throws IOException {
        Map<String, Integer> tdata = new HashMap<String, Integer>();
        for (String file : files) {
            BufferedReader br = new BufferedReader(new FileReader(file));
            try {
                String line;
                while ((line = br.readLine()) != null) {
                    String trimmed = line.trim();
                    if (trimmed.length() == 0) {
                        continue;
                    }
                    // Splitting on runs of whitespace avoids producing "" items.
                    Set<String> distinct =
                            new HashSet<String>(Arrays.asList(trimmed.split("\\s+")));
                    SimpleSet<String> lineItems = new SimpleSet<String>();
                    for (String item : distinct) {
                        lineItems.add(item);
                        Integer seen = tdata.get(item);
                        tdata.put(item, seen == null ? 1 : seen + 1);
                    }
                    lineSet.add(lineItems);
                }
            } finally {
                br.close();
            }
        }
        return tdata;
    }
}