Package com.blazebit.ai.decisiontree.impl

Source Code of com.blazebit.ai.decisiontree.impl.ID3AttributeSelector$Pair

package com.blazebit.ai.decisiontree.impl;

import com.blazebit.ai.decisiontree.Attribute;
import com.blazebit.ai.decisiontree.AttributeSelector;
import com.blazebit.ai.decisiontree.AttributeValue;
import com.blazebit.ai.decisiontree.DiscreteAttribute;
import com.blazebit.ai.decisiontree.Example;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;

/**
*
* @author Christian Beikov
*/
public class ID3AttributeSelector implements AttributeSelector<Boolean> {

    @Override
    public Attribute select(final Set<Example<Boolean>> examples, final Set<Attribute> availableAttributes, final Set<Attribute> usedAttributes) {
        Attribute attribute = null;
        float attributeRem = Float.MAX_VALUE;
        int attributeValueCount = Integer.MAX_VALUE;
        float positives = 0;
        float negatives = 0;

        final Map<Attribute, Map<AttributeValue, Pair>> attributeUsage = new HashMap<Attribute, Map<AttributeValue, Pair>>();
       
        /* Make array for performance */
        final Example<Boolean>[] exampleArray = examples.toArray(new Example[0]);
        final int examplesSize = exampleArray.length;
       
        for (final Attribute attr : availableAttributes) {
            if (usedAttributes.contains(attr)) {
                continue;
            }

            final Map<AttributeValue, Pair> valueUsage = new HashMap<AttributeValue, Pair>();
            attributeUsage.put(attr, valueUsage);

            for (int i = 0; i < examplesSize; i++) {
                final AttributeValue value = exampleArray[i].getValues().get(attr);
                Pair valueUsageExamples = valueUsage.get(value);

                if (valueUsageExamples == null) {
                    valueUsageExamples = new Pair();
                    valueUsage.put(value, valueUsageExamples);
                }

                if (exampleArray[i].getResult()) {
                    ++valueUsageExamples.positive;
                    ++positives;
                } else {
                    ++valueUsageExamples.negative;
                    ++negatives;
                }
            }
        }

        if (positives > 0 && negatives > 0) {
            for (final Map.Entry<Attribute, Map<AttributeValue, Pair>> entry : attributeUsage.entrySet()) {
                final Attribute attr = entry.getKey();
                final float rem = Pair.rem(entry.getValue().values(), positives, negatives);

                if (rem < attributeRem) {
                    attribute = attr;
                    attributeRem = rem;

                    if (attr instanceof DiscreteAttribute) {
                        attributeValueCount = ((DiscreteAttribute) attr).getValues().size();
                    }
                } else if (attr instanceof DiscreteAttribute && (rem == attributeRem) && ((DiscreteAttribute) attr).getValues().size() < attributeValueCount) {
                    attribute = attr;
                    attributeRem = rem;
                    attributeValueCount = ((DiscreteAttribute) attr).getValues().size();
                }
            }
        }

        return attribute;
    }

    private static class Pair {

        static final float log2 = (float) Math.log(2);
        float positive = 0;
        float negative = 0;

        double entropy() {
            final float localPositive = positive;
            final float localNegative = negative;
            final float localLog2 = log2;

            if (localPositive == 0 || localNegative == 0) {
                return 0;
            }

            final float p = localPositive / (localPositive + localNegative);
            return -p * (Math.log(p) / localLog2) - (1 - p) * (Math.log(1 - p) / localLog2);
        }

        static float rem(final Collection<Pair> pairs, final float positives, final float negatives) {
            float rem = 0;

            for (final Pair p : pairs) {
                rem += ((p.positive + p.negative) / (positives + negatives)) * p.entropy();
            }

            return rem;
        }
    }
}
TOP

Related Classes of com.blazebit.ai.decisiontree.impl.ID3AttributeSelector$Pair

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark originally of Sun Microsystems, Inc., now owned by Oracle. Contact coftware#gmail.com.