Package quickml.supervised.crossValidation

Source Code of quickml.supervised.crossValidation.OutOfTimeCrossValidator$TestDateTimeExtractor

package quickml.supervised.crossValidation;

import com.google.common.base.Optional;
import com.google.common.collect.Lists;
import org.joda.time.DateTime;
import org.joda.time.Period;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import quickml.data.AttributesMap;
import quickml.supervised.Utils;
import quickml.supervised.crossValidation.crossValLossFunctions.LabelPredictionWeight;
import quickml.supervised.crossValidation.crossValLossFunctions.LossWithModelConfiguration;
import quickml.supervised.crossValidation.crossValLossFunctions.MultiLossFunctionWithModelConfigurations;
import quickml.supervised.crossValidation.dateTimeExtractors.DateTimeExtractor;
import quickml.data.Instance;
import quickml.supervised.PredictiveModel;
import quickml.supervised.PredictiveModelBuilder;
import quickml.supervised.crossValidation.crossValLossFunctions.CrossValLossFunction;

import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.Map;

/**
* Created by alexanderhawk on 5/5/14.
*/
public class OutOfTimeCrossValidator<R, P> extends CrossValidator<R, P>{

    private static final Logger logger = LoggerFactory.getLogger(OutOfTimeCrossValidator.class);

    List<Instance<R>> allTrainingData;
    List<Instance<R>> trainingDataToAddToPredictiveModel;
    List<Instance<R>> validationSet;
    Optional<LabelConverter<R>> labelConverter = Optional.absent();

    final private CrossValLossFunction<P> crossValLossFunction;
    private double fractionOfDataForCrossValidation = 0.25;

    private final DateTimeExtractor<R> dateTimeExtractor;
    DateTime timeOfFirstInstanceInValidationSet;
    DateTime leastOuterBoundOfValidationSet;

    final Period durationOfValidationSet;
    private DateTime maxTime;
    private double weightOfValidationSet;
    private int currentTrainingSetSize = 0;
    int clicksInValSet = 0;

    public OutOfTimeCrossValidator(CrossValLossFunction<P> crossValLossFunction, double fractionOfDataForCrossValidation, int validationTimeSliceHours, DateTimeExtractor dateTimeExtractor) {
        this.crossValLossFunction = crossValLossFunction;
        this.fractionOfDataForCrossValidation = fractionOfDataForCrossValidation;
        this.dateTimeExtractor = dateTimeExtractor;
        this.durationOfValidationSet = new Period(validationTimeSliceHours, 0, 0, 0);
    }

    public OutOfTimeCrossValidator<R, P> labelConverter(LabelConverter<R> labelConverter) {
        this.labelConverter = Optional.of(labelConverter);
        return this;
    }

    @Override
    public <PM extends PredictiveModel<R, P>> double getCrossValidatedLoss(PredictiveModelBuilder<R, PM> predictiveModelBuilder, Iterable<? extends Instance<R>> rawTrainingData) {

        initializeTrainingAndValidationSets(rawTrainingData);

        double runningLoss = 0;
        double runningWeightOfValidationSet = 0;
        while (!validationSet.isEmpty()) {
            PM predictiveModel = predictiveModelBuilder.buildPredictiveModel(trainingDataToAddToPredictiveModel);
            List<LabelPredictionWeight<P>> labelPredictionWeights;
            List<Instance<R>> convertedValSet = validationSet;
            if (labelConverter.isPresent()) {
                convertedValSet = labelConverter.get().convertLabels(validationSet);
            }
            labelPredictionWeights = Utils.createLabelPredictionWeights(convertedValSet, predictiveModel);
            int positiveInstances  = 0;
            for (LabelPredictionWeight<P> labelPredictionWeight : labelPredictionWeights) {
                if (labelPredictionWeight.getLabel().equals(Double.valueOf(1.0)))
                    positiveInstances++;
            }

            runningLoss += crossValLossFunction.getLoss(labelPredictionWeights) * weightOfValidationSet;

            runningWeightOfValidationSet += weightOfValidationSet;
            logger.info("Running average Loss: " + runningLoss / runningWeightOfValidationSet + ", running weight: " + runningWeightOfValidationSet + ". pos instances: " + positiveInstances);

            updateTrainingSet();
            updateCrossValidationSet();
        }
        final double averageLoss = runningLoss / runningWeightOfValidationSet;
        logger.info("Average loss: " + averageLoss + ", runningWeight: " + runningWeightOfValidationSet);

        return averageLoss;
    }

    public <PM extends PredictiveModel<R, P>> MultiLossFunctionWithModelConfigurations getMultipleCrossValidatedLossesWithModelConfiguration(PredictiveModelBuilder<R, PM> predictiveModelBuilder, Iterable<? extends Instance<R>> rawTrainingData, MultiLossFunctionWithModelConfigurations<P> multiLossFunction) {

        initializeTrainingAndValidationSets(rawTrainingData);

        while (!validationSet.isEmpty()) {
            PM predictiveModel = predictiveModelBuilder.buildPredictiveModel(trainingDataToAddToPredictiveModel);
            List<LabelPredictionWeight<P>> labelPredictionWeights;
            List<Instance<R>> convertedValSet = validationSet;
            if (labelConverter.isPresent()) {
                convertedValSet = labelConverter.get().convertLabels(validationSet);
            }
            labelPredictionWeights = Utils.createLabelPredictionWeights(convertedValSet, predictiveModel);

            multiLossFunction.updateRunningLosses(labelPredictionWeights);
            updateTrainingSet();
            updateCrossValidationSet();
        }
        multiLossFunction.normalizeRunningAverages();
        Map<String, LossWithModelConfiguration> lossMap = multiLossFunction.getLossesWithModelConfigurations();
        for (String lossFunctionName : lossMap.keySet()) {
            logger.info("Loss function: " + lossFunctionName + "loss: " + lossMap.get(lossFunctionName).getLoss() + ".  Weight of val set: " + multiLossFunction.getRunningWeight());
        }
        return multiLossFunction;
    }


    private void initializeTrainingAndValidationSets(Iterable<? extends Instance<R>> rawTrainingData) {
        setAndSortAllTrainingData(rawTrainingData);
        setMaxValidationTime();

        int initialTrainingSetSize = getInitialSizeForTrainData();
        trainingDataToAddToPredictiveModel = Lists.<Instance<R>>newArrayListWithExpectedSize(initialTrainingSetSize);
        validationSet = Lists.<Instance<R>>newArrayList();

        DateTime timeOfInstance;

        timeOfFirstInstanceInValidationSet = dateTimeExtractor.extractDateTime(allTrainingData.get(initialTrainingSetSize));
        leastOuterBoundOfValidationSet = timeOfFirstInstanceInValidationSet.plus(durationOfValidationSet);

        weightOfValidationSet = 0;
        clicksInValSet = 0;
        for (Instance<R> instance : allTrainingData) {
            timeOfInstance = dateTimeExtractor.extractDateTime(instance);
            if (timeOfInstance.isBefore(timeOfFirstInstanceInValidationSet)) {
                trainingDataToAddToPredictiveModel.add(instance);
            } else if (timeOfInstance.isBefore(leastOuterBoundOfValidationSet)) {
                validationSet.add(instance);
                weightOfValidationSet += instance.getWeight();
                if (instance.getLabel().equals(Double.valueOf(1.0))) {
                    clicksInValSet++;
                }

            } else {
                break;
            }
        }
      //  logger.info("timeOfFirstInstanceInValidationSet: " + timeOfFirstInstanceInValidationSet.toString() + "\nleastOuterBoundOfValidationSet: " + leastOuterBoundOfValidationSet.toString());
      //  logger.info("initial clicks in valset: " + clicksInValSet);
        currentTrainingSetSize = trainingDataToAddToPredictiveModel.size();
    }


    private void updateTrainingSet() {
        trainingDataToAddToPredictiveModel = validationSet;
        currentTrainingSetSize += trainingDataToAddToPredictiveModel.size();
    }

    private void updateCrossValidationSet() {
        clearValidationSet();
        if (!newValidationSetExists()) {
            return;
        }
        timeOfFirstInstanceInValidationSet = leastOuterBoundOfValidationSet;
        leastOuterBoundOfValidationSet = timeOfFirstInstanceInValidationSet.plus(durationOfValidationSet);
   //     logger.info("first val set instance: " + timeOfFirstInstanceInValidationSet + "\n time of last instance " + leastOuterBoundOfValidationSet);

        while(validationSet.isEmpty()) {
            for (int i = currentTrainingSetSize; i < allTrainingData.size(); i++) {
                Instance<R> instance = allTrainingData.get(i);
                DateTime timeOfInstance = dateTimeExtractor.extractDateTime(instance);
                if (timeOfInstance.isBefore(leastOuterBoundOfValidationSet)) {
                    validationSet.add(instance);
                    weightOfValidationSet += instance.getWeight();
                    if (instance.getLabel().equals(Double.valueOf(1.0)))
                        clicksInValSet++;
                } else
                    break;
            }
        }
        logger.info("clicks in val set: " + clicksInValSet);
    }

    private void clearValidationSet() {
        weightOfValidationSet = 0;
        validationSet = Lists.<Instance<R>>newArrayList();
    }


    private void setMaxValidationTime() {
        Instance<R> latestInstance = allTrainingData.get(allTrainingData.size() - 1);
        maxTime = dateTimeExtractor.extractDateTime(latestInstance);
    }

    private int getInitialSizeForTrainData() {
        int initialTrainingSetSize = (int) (allTrainingData.size() * (1 - fractionOfDataForCrossValidation));
        verifyInitialValidationSetExists(initialTrainingSetSize);
        return initialTrainingSetSize;
    }


    private void verifyInitialValidationSetExists(int initialTrainingSetSize) {
        if (initialTrainingSetSize == allTrainingData.size()) {
            throw new RuntimeException("fractionOfDataForCrossValidation must be non zero");
        }
    }

    private boolean newValidationSetExists() {
        return currentTrainingSetSize < allTrainingData.size();
    }

    private void setAndSortAllTrainingData(Iterable<? extends Instance<R>> rawTrainingData) {

        this.allTrainingData = Lists.<Instance<R>>newArrayList();
        for (Instance<R> instance : rawTrainingData) {
            this.allTrainingData.add(instance);
        }

        Comparator<Instance<R>> comparator = new Comparator<Instance<R>>() {
            @Override
            public int compare(Instance<R> o1, Instance<R> o2) {
                DateTime firstInstance = dateTimeExtractor.extractDateTime(o1);
                DateTime secondInstance = dateTimeExtractor.extractDateTime(o2);
                if (firstInstance.isAfter(secondInstance)) {
                    return 1;
                } else if (firstInstance.isEqual(secondInstance)) {
                    return 0;
                } else {
                    return -1;
                }
            }
        };

        Collections.sort(this.allTrainingData, comparator);
    }

    static class TestDateTimeExtractor implements DateTimeExtractor<AttributesMap> {
        @Override
        public DateTime extractDateTime(Instance<AttributesMap> instance) {
            AttributesMap attributes = instance.getAttributes();
            int year = ((Long) attributes.get("timeOfArrival-year")).intValue();
            int month = ((Long) attributes.get("timeOfArrival-monthOfYear")).intValue();
            int day = ((Long) attributes.get("timeOfArrival-dayOfMonth")).intValue();
            int hour = ((Long) attributes.get("timeOfArrival-hourOfDay")).intValue();
            int minute = ((Long) attributes.get("timeOfArrival-minuteOfHour")).intValue();
            return new DateTime(year, month, day, hour, minute, 0, 0);
        }
    }

}
TOP

Related Classes of quickml.supervised.crossValidation.OutOfTimeCrossValidator$TestDateTimeExtractor

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.