Package quickml.supervised.classifier.splitOnAttribute

Source Code of quickml.supervised.classifier.splitOnAttribute.SplitOnAttributeClassifier

package quickml.supervised.classifier.splitOnAttribute;

import quickml.data.AttributesMap;
import quickml.data.PredictionMap;
import quickml.supervised.classifier.AbstractClassifier;
import quickml.supervised.classifier.Classifier;

import java.io.IOException;
import java.io.Serializable;
import java.util.Map;

/**
* Created by ian on 5/29/14.
*/
public class SplitOnAttributeClassifier extends AbstractClassifier {
    private static final long serialVersionUID = 2642074639257374588L;
    private final String attributeKey;
    private final Map<Serializable, Classifier> splitModels;
    private final Classifier defaultPM;

    public SplitOnAttributeClassifier(String attributeKey, final Map<Serializable, Classifier> splitModels, Classifier defaultPM) {
        this.attributeKey = attributeKey;
        this.splitModels = splitModels;
        this.defaultPM = defaultPM;
    }

    @Override
    public double getProbability(final AttributesMap attributes, final Serializable classification) {
        return getModelForAttributes(attributes).getProbability(attributes, classification);
    }

    @Override
    public PredictionMap predict(final AttributesMap attributes) {
        return getModelForAttributes(attributes).predict(attributes);
    }

    @Override
    public void dump(final Appendable appendable) {
        try {
            for (Map.Entry<Serializable, Classifier> splitModelEntry : splitModels.entrySet()) {
                appendable.append("Predictive model for " + attributeKey + "=" + splitModelEntry.getKey());
                splitModelEntry.getValue().dump(appendable);
            }
            appendable.append("Default");
            defaultPM.dump(appendable);
        } catch (IOException e) {
            throw new RuntimeException(e);
        }
    }

    @Override
    public Serializable getClassificationByMaxProb(final AttributesMap attributes) {
        return getModelForAttributes(attributes).getClassificationByMaxProb(attributes);
    }

    public Classifier getDefaultPM() {
        return defaultPM;
    }

    public Map<Serializable, Classifier> getSplitModels() {
        return splitModels;
    }

    private Classifier getModelForAttributes(AttributesMap attributes) {
        Serializable value = attributes.get(attributeKey);
        if (value == null) value = SplitOnAttributeClassifierBuilder.NO_VALUE_PLACEHOLDER;
        Classifier predictiveModel = splitModels.get(value);
        if (predictiveModel == null) {
            predictiveModel = defaultPM;
        }
        return predictiveModel;
    }

}
TOP

Related Classes of quickml.supervised.classifier.splitOnAttribute.SplitOnAttributeClassifier

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.