/**
 * Selects the attribute whose split of the samples has the lowest weighted
 * entropy (i.e. the split that leaves the categories least mixed).
 *
 * <p>For every candidate attribute the samples are partitioned by attribute
 * value, and inside each partition by category. The entropy of each
 * partition's category distribution is weighted by the partition's share of
 * all samples and summed. The attribute with the smallest sum wins; on ties
 * the attribute that appears first in {@code attrNames} is kept.
 *
 * <p>Attribute values are used as {@code String} map keys. A non-null value is
 * converted with {@code toString()}; a {@code null} value is kept as a
 * {@code null} key.
 *
 * @param categoryToSamples samples grouped by category
 * @param attrNames         names of the candidate attributes
 * @return a three-element array:
 *         {@code [0]} an {@code Integer} index into {@code attrNames} of the
 *         chosen attribute ({@code -1} when {@code attrNames} is empty);
 *         {@code [1]} a {@code Double} holding that attribute's weighted
 *         entropy ({@code Double.MAX_VALUE} when {@code attrNames} is empty);
 *         {@code [2]} a {@code Map<String, Map<String, List<Sample>>>} mapping
 *         attribute value to category to samples for the chosen attribute
 *         ({@code null} when {@code attrNames} is empty)
 */
static Object[] chooseBestTestAttribute(Map<String, List<Sample>> categoryToSamples, String[] attrNames) {
    int minIndex = -1;
    double minValue = Double.MAX_VALUE;
    Map<String, Map<String, List<Sample>>> minSplits = null;

    for (int attrIndex = 0; attrIndex < attrNames.length; attrIndex++) {
        String attrName = attrNames[attrIndex];
        int allCount = 0;

        // attribute value -> (category -> samples having that value)
        Map<String, Map<String, List<Sample>>> curSplits = new HashMap<>();
        for (Entry<String, List<Sample>> entry : categoryToSamples.entrySet()) {
            // Keep the key typed as String so it can be stored in the String-keyed split maps.
            String category = entry.getKey();
            List<Sample> samples = entry.getValue();
            for (Sample sample : samples) {
                // The split maps are keyed by String, so derive a String key from the attribute value.
                Object rawValue = sample.getAttribute(attrName);
                String attrValue = rawValue == null ? null : rawValue.toString();

                Map<String, List<Sample>> split = curSplits.get(attrValue);
                if (split == null) {
                    split = new HashMap<>();
                    curSplits.put(attrValue, split);
                }
                List<Sample> splitSamples = split.get(category);
                if (splitSamples == null) {
                    splitSamples = new LinkedList<>();
                    split.put(category, splitSamples);
                }
                splitSamples.add(sample);
            }
            allCount += samples.size();
        }

        // Weighted sum of the per-partition entropies.
        double curValue = 0.0;
        for (Map<String, List<Sample>> splits : curSplits.values()) {
            // Held as a double so the divisions below are floating-point.
            double perSplitCount = 0;
            for (List<Sample> list : splits.values()) {
                perSplitCount += list.size();
            }
            double perSplitValue = 0.0;
            for (List<Sample> list : splits.values()) {
                // Lists are only created when a sample is added, so p > 0 and log(p) is finite.
                double p = list.size() / perSplitCount;
                perSplitValue -= p * (Math.log(p) / Math.log(2));
            }
            curValue += (perSplitCount / allCount) * perSplitValue;
        }

        // Strict comparison: the earliest attribute wins a tie.
        if (minValue > curValue) {
            minIndex = attrIndex;
            minValue = curValue;
            minSplits = curSplits;
        }
    }
    return new Object[]{minIndex, minValue, minSplits};
}
/**
 * Entry point: loads the built-in samples, builds a decision tree over the
 * four attributes, and hands the result to {@code outputDecisionTree}.
 *
 * @param args command-line arguments (not used)
 * @throws Exception declared by the signature; nothing in this method throws a checked exception directly
 */
public static void main(String[] args) throws Exception {
    final String[] attributes = { "AGE", "INCOME", "STUDENT", "CREDIT_RATING" };
    final Map<String, List<Sample>> trainingSet = readSamples(attributes);
    final Object root = generateDecisionTree(trainingSet, attributes);
    // NOTE(review): 0 and null are presumably the starting depth and "no parent value" — confirm against outputDecisionTree.
    outputDecisionTree(root, 0, null);
}
/**
 * Builds the hard-coded training set and groups the samples by category.
 *
 * <p>Each row of the table holds the attribute values followed by the
 * category label in its last column. The i-th attribute column is stored on
 * the sample under {@code attrNames[i]}.
 *
 * @param attrNames names for the attribute columns, in column order; it must
 *                  have at least as many entries as a row has attribute
 *                  columns (four here), otherwise indexing it throws
 *                  {@code ArrayIndexOutOfBoundsException}
 * @return a map from category label to the list of samples with that label
 */
static Map<String, List<Sample>> readSamples(String[] attrNames) {
    Object[][] rawData = new Object[][] {
        { "<30 ", "High ", "No ", "Fair ", "0" },
        { "<30 ", "High ", "No ", "Excellent", "0" },
        { "30-40", "High ", "No ", "Fair ", "1" },
        { ">40 ", "Medium", "No ", "Fair ", "1" },
        { ">40 ", "Low ", "Yes", "Fair ", "1" },
        { ">40 ", "Low ", "Yes", "Excellent", "0" },
        { "30-40", "Low ", "Yes", "Excellent", "1" },
        { "<30 ", "Medium", "No ", "Fair ", "0" },
        { "<30 ", "Low ", "Yes", "Fair ", "1" },
        { ">40 ", "Medium", "Yes", "Fair ", "1" },
        { "<30 ", "Medium", "Yes", "Excellent", "1" },
        { "30-40", "Medium", "No ", "Excellent", "1" },
        { "30-40", "High ", "Yes", "Fair ", "1" },
        { ">40 ", "Medium", "No ", "Excellent", "0" }
    };

    Map<String, List<Sample>> ret = new HashMap<>();
    for (Object[] row : rawData) {
        int n = row.length;
        Object label = row[n - 1];

        Sample sample = new Sample();
        for (int i = 0; i < n - 1; i++) {
            sample.setAttribute(attrNames[i], row[i]);
        }
        sample.setCategory(label);

        // The result map is keyed by String; every label in rawData is a
        // String literal, so this cast cannot fail.
        String category = (String) label;
        List<Sample> list = ret.get(category);
        if (list == null) {
            list = new LinkedList<>();
            ret.put(category, list);
        }
        list.add(sample);
    }
    return ret;
}
/**
 * Recursively builds a decision tree for the given samples.
 *
 * <ul>
 *   <li>If only one category is present, that category key is returned as a leaf.</li>
 *   <li>If no attributes remain, the category with the most samples is
 *       returned (the first one encountered wins a tie; {@code null} if no
 *       category has any samples).</li>
 *   <li>Otherwise the attribute chosen by {@code chooseBestTestAttribute}
 *       becomes a {@code Tree} node, and one child is built per attribute
 *       value from that value's samples and the remaining attributes.</li>
 * </ul>
 *
 * @param categoryToSamples samples grouped by category
 * @param attrNames         attributes still available for splitting
 * @return a category key for a leaf, a {@code Tree} for an inner node, or
 *         {@code null} as described above
 */
static Object generateDecisionTree(Map<String, List<Sample>> categoryToSamples, String[] attrNames) {
    // Every sample shares one category: nothing left to decide.
    if (categoryToSamples.size() == 1) {
        return categoryToSamples.keySet().iterator().next();
    }

    // Out of attributes: fall back to the most populous category.
    if (attrNames.length == 0) {
        Object winner = null;
        int winnerSize = 0;
        for (Entry<String, List<Sample>> group : categoryToSamples.entrySet()) {
            int size = group.getValue().size();
            if (size > winnerSize) {
                winnerSize = size;
                winner = group.getKey();
            }
        }
        return winner;
    }

    Object[] best = chooseBestTestAttribute(categoryToSamples, attrNames);
    int bestIndex = (Integer) best[0];
    Tree node = new Tree(attrNames[bestIndex]);

    // Copy every attribute name except the chosen one, keeping their order.
    String[] remaining = new String[attrNames.length - 1];
    System.arraycopy(attrNames, 0, remaining, 0, bestIndex);
    System.arraycopy(attrNames, bestIndex + 1, remaining, bestIndex, remaining.length - bestIndex);

    @SuppressWarnings("unchecked")
    Map<String, Map<String, List<Sample>>> partitions = (Map<String, Map<String, List<Sample>>>) best[2];
    for (Entry<String, Map<String, List<Sample>>> branch : partitions.entrySet()) {
        Object attrValue = branch.getKey();
        Object child = generateDecisionTree(branch.getValue(), remaining);
        node.setChild(attrValue, child);
    }
    return node;
}