// Build one SumLattice per training instance. The list is index-aligned with
// data: on the single-threaded path, instances whose instancesWithConstraints
// bit is false get a null entry (the multi-threaded path presumably does the
// same inside SumLatticeTask -- confirm), and later code looks lattices up by
// instance index.
ArrayList<SumLattice> lattices = new ArrayList<SumLattice>(data.size());
if (numThreads == 1) {
    for (int ii = 0; ii < data.size(); ii++) {
        if (instancesWithConstraints.get(ii)) {
            // NOTE(review): the trailing `true` presumably asks the lattice to
            // keep its xis, which the gradient code reads via getXis() -- confirm.
            SumLatticeDefault lattice = new SumLatticeDefault(
                this.crf, (FeatureVectorSequence)data.get(ii).getData(),
                null, null, true);
            lattices.add(lattice);
        }
        else {
            lattices.add(null);
        }
    }
}
else {
    // multi-threaded version: split [0, data.size()) into numThreads contiguous
    // ranges; the last range absorbs the remainder of the integer division.
    ArrayList<Callable<Void>> tasks = new ArrayList<Callable<Void>>();
    // Never use more threads than instances, but keep at least one so the
    // range size below cannot divide by zero when data is empty.
    if (data.size() < numThreads) {
        numThreads = Math.max(1, data.size());
    }
    int increment = data.size() / numThreads;
    int start = 0;
    int end = increment;
    for (int thread = 0; thread < numThreads; thread++) {
        tasks.add(new SumLatticeTask(crf,data,instancesWithConstraints,start,end));
        start += increment;
        if (thread == numThreads - 2) {
            end = data.size();
        }
        else {
            end += increment;
        }
    }
    try {
        // Run all tasks and wait for them to finish. invokeAll records any
        // exception a task throws in its Future rather than rethrowing it, so
        // check every Future instead of continuing with missing lattices.
        for (java.util.concurrent.Future<Void> future : executor.invokeAll(tasks)) {
            future.get();
        }
    } catch (InterruptedException ie) {
        // Restore the interrupt flag and abort: after an interrupt, unfinished
        // tasks are cancelled and the lattices may no longer line up with data.
        Thread.currentThread().interrupt();
        throw new RuntimeException("Interrupted while computing lattices", ie);
    } catch (java.util.concurrent.ExecutionException ee) {
        throw new RuntimeException("Lattice computation task failed", ee);
    }
    // Tasks were created in index order, so concatenating their results
    // preserves the alignment between lattices and data.
    for (Callable<Void> task : tasks) {
        lattices.addAll(((SumLatticeTask)task).getLattices());
    }
    assert(lattices.size() == data.size()) : lattices.size() + " " + data.size();
}
System.err.println("Done computing lattices.");
// For every GE constraint: discard previously accumulated expectations, then
// recompute them from the freshly built, instance-aligned lattices.
for (GEConstraint geConstraint : constraints) {
    geConstraint.zeroExpectations();
    geConstraint.computeExpectations(lattices);
}
System.err.println("Done computing expectations.");
// GE objective value: reset the cached total, then add each constraint's
// individual value to it.
this.cachedValue = 0;
for (GEConstraint geConstraint : constraints) {
    this.cachedValue += geConstraint.getValue();
}
// Clear the gradient buffer before the gradient code that follows receives
// cachedGradient.
cachedGradient.zero();
// compute GE gradient
if (numThreads == 1) {
for (int ii = 0; ii < data.size(); ii++) {
if (instancesWithConstraints.get(ii)) {
SumLattice lattice = lattices.get(ii);
FeatureVectorSequence fvs = (FeatureVectorSequence)data.get(ii).getData();
new GELattice(fvs, lattice.getGammas(), lattice.getXis(), crf, reverseTrans, reverseTransIndices, cachedGradient,this.constraints, false);
}
}
}
else {
// multi-threaded version