/*
* This file is part of ALOE.
*
* ALOE is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
* ALOE is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU General Public License for more details.
* You should have received a copy of the GNU General Public License
* along with ALOE. If not, see <http://www.gnu.org/licenses/>.
*
* Copyright (c) 2012 SCCL, University of Washington (http://depts.washington.edu/sccl)
*/
package etc.aloe.factories;
import etc.aloe.controllers.CrossValidationController;
import etc.aloe.controllers.LabelingController;
import etc.aloe.controllers.TrainingController;
import etc.aloe.cscw2013.DownsampleBalancing;
import etc.aloe.cscw2013.FeatureExtractionImpl;
import etc.aloe.cscw2013.FeatureGenerationImpl;
import etc.aloe.cscw2013.LabelMappingImpl;
import etc.aloe.cscw2013.NullSegmentation;
import etc.aloe.cscw2013.ResolutionImpl;
import etc.aloe.cscw2013.SMOFeatureWeighting;
import etc.aloe.cscw2013.ThresholdSegmentation;
import etc.aloe.cscw2013.TrainingImpl;
import etc.aloe.cscw2013.UpsampleBalancing;
import etc.aloe.cscw2013.WekaModel;
import etc.aloe.data.Model;
import etc.aloe.filters.StringToDictionaryVector;
import etc.aloe.options.InteractiveOptions;
import etc.aloe.options.LabelOptions;
import etc.aloe.options.ModeOptions;
import etc.aloe.options.SingleOptions;
import etc.aloe.options.TrainOptions;
import etc.aloe.processes.Balancing;
import etc.aloe.processes.FeatureExtraction;
import etc.aloe.processes.FeatureGeneration;
import etc.aloe.processes.FeatureWeighting;
import etc.aloe.processes.LabelMapping;
import etc.aloe.processes.SegmentResolution;
import etc.aloe.processes.Segmentation;
import etc.aloe.processes.Training;
import java.io.File;
import java.io.FileNotFoundException;
import java.text.DateFormat;
import java.text.SimpleDateFormat;
import java.util.List;
import org.kohsuke.args4j.Option;
/**
* Provides implementations for the CSCW 2013 pipeline.
*
* @author Michael Brooks <mjbrooks@uw.edu>
*/
public class CSCW2013 implements PipelineFactory {
public ModeOptions options;
@Override
public void initialize() {
double falseNegativeCost = 1;
double falsePositiveCost = 1;
if (options instanceof TrainOptionsImpl) {
TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
falseNegativeCost = trainOpts.falseNegativeCost;
falsePositiveCost = trainOpts.falsePositiveCost;
} else if (options instanceof LabelOptionsImpl) {
LabelOptionsImpl labelOpts = (LabelOptionsImpl) options;
falseNegativeCost = labelOpts.falseNegativeCost;
falsePositiveCost = labelOpts.falsePositiveCost;
}
//Normalize the cost factors (sum to 2)
double costNormFactor = 0.5 * (falseNegativeCost + falsePositiveCost);
falseNegativeCost /= costNormFactor;
falsePositiveCost /= costNormFactor;
System.out.println("Costs normalized to " + falseNegativeCost + " (FN) " + falsePositiveCost + " (FP).");
if (options instanceof TrainOptionsImpl) {
TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
trainOpts.falseNegativeCost = falseNegativeCost;
trainOpts.falsePositiveCost = falsePositiveCost;
} else if (options instanceof LabelOptionsImpl) {
LabelOptionsImpl labelOpts = (LabelOptionsImpl) options;
labelOpts.falseNegativeCost = falseNegativeCost;
labelOpts.falsePositiveCost = falsePositiveCost;
}
}
protected List<String> loadTermList(File emoticonFile) {
try {
return StringToDictionaryVector.readDictionaryFile(emoticonFile);
} catch (FileNotFoundException ex) {
System.err.println("Unable to read emoticon dictionary file " + emoticonFile);
System.err.println("\t" + ex.getMessage());
System.exit(1);
}
return null;
}
@Override
public Model constructModel() {
return new WekaModel();
}
@Override
public FeatureExtraction constructFeatureExtraction() {
return new FeatureExtractionImpl();
}
@Override
public FeatureGeneration constructFeatureGeneration() {
if (options instanceof TrainOptionsImpl) {
TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
//Read the emoticons
List<String> termList = loadTermList(trainOpts.emoticonFile);
FeatureGenerationImpl featureGen = new FeatureGenerationImpl(termList);
featureGen.setParticipantFeatureCount(trainOpts.participantFeatures);
return featureGen;
} else {
throw new IllegalArgumentException("Options not for Training");
}
}
@Override
public LabelMapping constructLabelMapping() {
return new LabelMappingImpl();
}
@Override
public SegmentResolution constructSegmentResolution() {
return new ResolutionImpl();
}
@Override
public FeatureWeighting constructFeatureWeighting() {
return new SMOFeatureWeighting();
}
@Override
public Segmentation constructSegmentation() {
boolean disableSegmentation = false;
int segmentationThresholdSeconds = 30;
boolean ignoreParticipants = false;
if (options instanceof TrainOptionsImpl) {
TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
disableSegmentation = trainOpts.disableSegmentation;
segmentationThresholdSeconds = trainOpts.segmentationThresholdSeconds;
ignoreParticipants = trainOpts.ignoreParticipants;
} else if (options instanceof LabelOptionsImpl) {
LabelOptionsImpl labelOpts = (LabelOptionsImpl) options;
disableSegmentation = labelOpts.disableSegmentation;
segmentationThresholdSeconds = labelOpts.segmentationThresholdSeconds;
ignoreParticipants = labelOpts.ignoreParticipants;
} else {
throw new IllegalArgumentException("Options should be for Training or Labeling");
}
if (disableSegmentation) {
return new NullSegmentation();
} else {
Segmentation segmentation = new ThresholdSegmentation(segmentationThresholdSeconds,
!ignoreParticipants);
segmentation.setSegmentResolution(new ResolutionImpl());
return segmentation;
}
}
@Override
public Training constructTraining() {
if (options instanceof TrainOptionsImpl) {
TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
TrainingImpl trainingImpl = new TrainingImpl();
trainingImpl.setBuildLogisticModel(true);
if (trainOpts.useMinCost || trainOpts.useReweighting) {
trainingImpl.setUseCostTraining(true);
trainingImpl.setFalsePositiveCost(trainOpts.falsePositiveCost);
trainingImpl.setFalseNegativeCost(trainOpts.falseNegativeCost);
trainingImpl.setUseReweighting(trainOpts.useReweighting);
}
return trainingImpl;
} else {
throw new IllegalArgumentException("Options must be for Training");
}
}
@Override
public Balancing constructBalancing() {
if (options instanceof TrainOptionsImpl) {
TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
if (trainOpts.useDownsampling) {
return new DownsampleBalancing(trainOpts.falsePositiveCost, trainOpts.falseNegativeCost);
} else if (trainOpts.useUpsampling) {
return new UpsampleBalancing(trainOpts.falsePositiveCost, trainOpts.falseNegativeCost);
} else {
return null;
}
} else {
throw new IllegalArgumentException("Options must be for Training");
}
}
@Override
public void configureLabeling(LabelingController labelingController) {
if (options instanceof LabelOptions) {
LabelOptionsImpl labelOpts = (LabelOptionsImpl) options;
//Options
labelingController.setCosts(labelOpts.falsePositiveCost, labelOpts.falseNegativeCost);
//Implementations
labelingController.setFeatureExtractionImpl(constructFeatureExtraction());
labelingController.setMappingImpl(constructLabelMapping());
}
}
@Override
public void configureCrossValidation(CrossValidationController crossValidationController) {
if (options instanceof TrainOptionsImpl) {
TrainOptionsImpl trainOpts = (TrainOptionsImpl) options;
//Implementations
crossValidationController.setFeatureGenerationImpl(this.constructFeatureGeneration());
crossValidationController.setFeatureExtractionImpl(this.constructFeatureExtraction());
crossValidationController.setTrainingImpl(this.constructTraining());
crossValidationController.setBalancingImpl(this.constructBalancing());
crossValidationController.setMappingImpl(this.constructLabelMapping());
//Options
crossValidationController.setFolds(trainOpts.crossValidationFolds);
crossValidationController.setCosts(trainOpts.falsePositiveCost, trainOpts.falseNegativeCost);
crossValidationController.setBalanceTestSet(trainOpts.balanceTestSet);
} else {
throw new IllegalArgumentException("Options must be for Training");
}
}
@Override
public void configureTraining(TrainingController trainingController) {
trainingController.setFeatureGenerationImpl(this.constructFeatureGeneration());
trainingController.setFeatureExtractionImpl(this.constructFeatureExtraction());
trainingController.setTrainingImpl(this.constructTraining());
trainingController.setFeatureWeightingImpl(this.constructFeatureWeighting());
trainingController.setBalancingImpl(this.constructBalancing());
}
@Override
public DateFormat constructDateFormat() {
return new SimpleDateFormat(options.dateFormatString);
}
@Override
public InteractiveOptions constructInteractiveOptions() {
return new InteractiveOptionsImpl();
}
@Override
public LabelOptions constructLabelOptions() {
return new LabelOptionsImpl();
}
@Override
public TrainOptions constructTrainOptions() {
return new TrainOptionsImpl();
}
@Override
public SingleOptions constructSingleOptions() {
return new SingleOptionsImpl();
}
@Override
public void setOptions(ModeOptions options) {
this.options = options;
}
static class InteractiveOptionsImpl extends InteractiveOptions {
}
static class SingleOptionsImpl extends SingleOptions {
}
static class LabelOptionsImpl extends LabelOptions {
@Option(name = "--fp-cost", usage = "the cost of a false positive (default 1)", metaVar = "COST")
public double falsePositiveCost = 1;
@Option(name = "--fn-cost", usage = "the cost of a false negative (default 1)", metaVar = "COST")
public double falseNegativeCost = 1;
@Option(name = "--ignore-participants", usage = "ignore participants during segmentation")
public boolean ignoreParticipants = false;
@Option(name = "--threshold", aliases = {"-t"}, usage = "segmentation threshold in seconds (default 30)", metaVar = "SECONDS")
public int segmentationThresholdSeconds = 30;
@Option(name = "--no-segmentation", usage = "disable segmentation (each message is in its own segment)")
public boolean disableSegmentation = false;
}
static class TrainOptionsImpl extends TrainOptions {
@Option(name="--participant-features", usage="use up to this many participant names as features")
public int participantFeatures = 0;
@Option(name = "--fp-cost", usage = "the cost of a false positive (default 1)", metaVar = "COST")
public double falsePositiveCost = 1;
@Option(name = "--fn-cost", usage = "the cost of a false negative (default 1)", metaVar = "COST")
public double falseNegativeCost = 1;
@Option(name = "--ignore-participants", usage = "ignore participants during segmentation")
public boolean ignoreParticipants = false;
@Option(name = "--threshold", aliases = {"-t"}, usage = "segmentation threshold in seconds (default 30)", metaVar = "SECONDS")
public int segmentationThresholdSeconds = 30;
@Option(name = "--no-segmentation", usage = "disable segmentation (each message is in its own segment)")
public boolean disableSegmentation = false;
@Option(name = "--upsample", aliases = {"-us"}, usage = "upsample the minority class in training sets to match the cost ratio")
public boolean useUpsampling = false;
@Option(name = "--reweight", aliases = {"-rw"}, usage = "reweight the training data")
public boolean useReweighting = false;
@Option(name = "--min-cost", usage = "train a classifier that uses the min-cost criterion")
public boolean useMinCost = false;
@Option(name = "--downsample", aliases = {"-ds"}, usage = "downsample the majority class in training sets to match the cost ratio")
public boolean useDownsampling = false;
@Option(name = "--folds", aliases = {"-k"}, usage = "number of cross-validation folds (default 10, 0 to disable cross validation)", metaVar = "FOLDS")
public int crossValidationFolds = 10;
@Option(name = "--balance-test-set", usage = "apply balancing to the test set as well as the training set")
public boolean balanceTestSet = false;
@Option(name = "--emoticons", aliases = {"-e"}, usage = "emoticon dictionary file (default emoticons.txt)")
public File emoticonFile = new File("emoticons.txt");
}
}