// Build the training and/or test InstanceLists by running input through pipe p.
// Three cases:
//  1. Training requested: read trainingFile. If a test option is also given,
//     either read testFile or, when no test file was supplied, randomly split
//     the training data into training/testing portions.
//  2. Test option only (no training): read test data from testFile with
//     target processing enabled.
//  3. Neither: create an empty test list with target processing disabled.
// Instances are grouped by LineGroupIterator using blank/whitespace-only lines
// (regex ^\s*$) as separators.
if (trainOption.value)
{
// Enable target processing before constructing the list so the training
// instances are read with their targets.
p.setTargetProcessing(true);
trainingData = new InstanceList(p);
trainingData.addThruPipe(
new LineGroupIterator(trainingFile,
Pattern.compile("^\\s*$"), true));
// The data alphabet has been populated by piping the training data.
logger.info
("Number of features in training data: "+p.getDataAlphabet().size());
if (testOption.value != null)
{
if (testFile != null)
{
// A separate test file was supplied; pipe it through the same pipe.
testData = new InstanceList(p);
testData.addThruPipe(
new LineGroupIterator(testFile,
Pattern.compile("^\\s*$"), true));
}
else
{
// No test file: split the training data. The split is seeded by
// randomSeedOption, so it is reproducible. trainingFractionOption
// goes to training and the remainder to testing.
Random r = new Random (randomSeedOption.value);
InstanceList[] trainingLists =
trainingData.split(
r, new double[] {trainingFractionOption.value,
1-trainingFractionOption.value});
trainingData = trainingLists[0];
testData = trainingLists[1];
}
}
} else if (testOption.value != null)
{
// Test-only run: trainingData is not assigned on this path.
p.setTargetProcessing(true);
testData = new InstanceList(p);
testData.addThruPipe(
new LineGroupIterator(testFile,
Pattern.compile("^\\s*$"), true));
} else
{
// Neither training nor a test option: targets are not processed, and the
// list is left empty here (reading testFile is intentionally disabled below).
p.setTargetProcessing(false);
testData = new InstanceList(p);
//testData.addThruPipe(
// new LineGroupIterator(testFile,
// Pattern.compile("^\\s*$"), true));
}
//logger.info ("Number of predicates: "+p.getDataAlphabet().size());
// Choose an evaluator based on the value of the test option:
//   "lab..."                 -> TokenAccuracyEvaluator
//   "seg=S1.C1,S2.C2,..."    -> MultiSegmentationEvaluator, built from
//                               comma-separated start.continue label pairs
//   anything else            -> print usage and throw IllegalArgumentException
// Both evaluators receive {trainingData, testData} labelled "Training"/"Testing".
// NOTE(review): on the test-only path trainingData is not assigned in the code
// above, so it may be null here; confirm the evaluators tolerate a null list.
if (testOption.value != null)
{
if (testOption.value.startsWith("lab"))
eval = new TokenAccuracyEvaluator(new InstanceList[] {trainingData, testData}, new String[] {"Training", "Testing"});
else if (testOption.value.startsWith("seg="))
{
// Strip the "seg=" prefix and split into "start.continue" pairs.
String[] pairs = testOption.value.substring(4).split(",");
// String.split drops trailing empty strings, so e.g. "seg=," yields no pairs.
if (pairs.length < 1)
{
commandOptions.printUsage(true);
throw new IllegalArgumentException(
"Missing segment start/continue labels: " + testOption.value);
}
String startTags[] = new String[pairs.length];
String continueTags[] = new String[pairs.length];
for (int i = 0; i < pairs.length; i++)
{
// Each pair must be exactly "<startLabel>.<continueLabel>".
String[] pair = pairs[i].split("\\.");
if (pair.length != 2)
{
commandOptions.printUsage(true);
throw new
IllegalArgumentException(
"Incorrectly-specified segment start and end labels: " +
pairs[i]);
}
startTags[i] = pair[0];
continueTags[i] = pair[1];
}
eval = new MultiSegmentationEvaluator(new InstanceList[] {trainingData, testData}, new String[] {"Training", "Testing"},
startTags, continueTags);
}
else
{
commandOptions.printUsage(true);
throw new IllegalArgumentException("Invalid test option: " +
testOption.value);
}
}
// When the pipe processes targets, log every entry of its target alphabet on
// one line, in index order, prefixed with "Labels:".
if (p.isTargetProcessing())
{
  Alphabet targetLabels = p.getTargetAlphabet();
  StringBuilder labelSummary = new StringBuilder("Labels:");
  int labelIdx = 0;
  while (labelIdx < targetLabels.size())
  {
    labelSummary.append(" ");
    labelSummary.append(targetLabels.lookupObject(labelIdx).toString());
    labelIdx++;
  }
  logger.info(labelSummary.toString());
}
if (trainOption.value)
{
crf = train(trainingData, testData, eval,
ordersOption.value, defaultOption.value,
forbiddenOption.value, allowedOption.value,
connectedOption.value, iterationsOption.value,
gaussianVarianceOption.value, crf);
if (modelOption.value != null)
{
ObjectOutputStream s =
new ObjectOutputStream(new FileOutputStream(modelOption.value));
s.writeObject(crf);
s.close();
}
}
else
{
if (crf == null)
{
if (modelOption.value == null)
{
commandOptions.printUsage(true);
throw new IllegalArgumentException("Missing model file option");
}
ObjectInputStream s =
new ObjectInputStream(new FileInputStream(modelOption.value));
crf = (CRF) s.readObject();
s.close();
}
if (eval != null)
test(new NoopTransducerTrainer(crf), eval, testData);
else
{
boolean includeInput = includeInputOption.value();
Scanner scanner = new Scanner(System.in);
Pattern pattern = Pattern.compile("^\\s*$");
int nLines = 0;
while (scanner.hasNextLine()) {
String line = scanner.nextLine();
line = line.replace('\t', '\n');
testData = new InstanceList(p);
testData.addThruPipe(new LineGroupIterator(new StringReader(line),
pattern, true));
for (int i = 0; i < testData.size(); i++) {
Sequence input = (Sequence)testData.get(i).getData();
Sequence[] outputs = apply(crf, input, nBestOption.value);