package gannuWSD.testing;
import gannuNLP.corpus.WikiCorpus;
import gannuNLP.data.AmbiguousWord;
import gannuNLP.data.Input;
import gannuNLP.dictionaries.DataBroker;
import gannuUtil.KeyArray;
import gannuUtil.KeyString;
import gannuUtil.Util;
import gannuWSD.algorithms.WSDAlgorithm;
import gannuWSD.bowmodifiers.BoWModifier;
import gannuWSD.sensefilters.FirstSenses;
import gannuWSD.sensefilters.NthSenseOnly;
import gannuWSD.sensefilters.RemoveNthSense;
import gannuWSD.sensefilters.SenseFilter;
import gannuWSD.skipfilters.SkipFilter;
import gannuWSD.windowfilters.WindowFilter;
import java.io.BufferedWriter;
import java.io.File;
import java.io.FileWriter;
import java.math.BigDecimal;
import java.util.ArrayList;
import javax.xml.parsers.DocumentBuilder;
import javax.xml.parsers.DocumentBuilderFactory;
import org.w3c.dom.Document;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;
/**
 * Class used for running experiments as specified in an XML file.
 * <p>
 * The configuration is read from these nodes:
 * <ul>
 * <li>a {@code <dict>} node (attributes {@code connector}, {@code version}, {@code path}, {@code sources});</li>
 * <li>optional {@code <bowmodifier>} nodes;</li>
 * <li>{@code <testset>} nodes;</li>
 * <li>an {@code <output>} node;</li>
 * <li>{@code <algorithm>} nodes, each with {@code <wsd>}, {@code <backoff>} and {@code <tie>} children.</li>
 * </ul>
 * @author Francisco Viveros-Jiménez
 *
 */
public class TestSet {
    /**
     * Loads a configuration file and runs every test it describes.
     * <p>
     * For every test set it loads (or reads from the on-disk cache) the inputs. It then builds one
     * {@link Test} per algorithm / parameter / window-filter / backoff / tie combination and runs each
     * test whose output file does not exist yet. A progress log is written to {@code ./output.txt}.
     * @param XMLfile Configuration file.
     * @throws IllegalArgumentException if a required configuration node is missing or a parameter range is malformed.
     * @throws Exception if parsing, loading, instantiating or running any component fails.
     */
    public static void runTests(String XMLfile)throws Exception
    {
        //Load the XML
        DocumentBuilderFactory fact = DocumentBuilderFactory.newInstance();
        DocumentBuilder builder = fact.newDocumentBuilder();
        Document testset=builder.parse(XMLfile);
        Element root=testset.getDocumentElement();
        //Initialize the DataBroker
        Element data=TestSet.requireElement(root,"dict");
        DataBroker dict=new DataBroker(data.getAttribute("connector"),data.getAttribute("version"));
        dict.setPath(data.getAttribute("path"));
        dict.getSource().setPath(data.getAttribute("path"));
        String sources=data.getAttribute("sources");
        dict.load(sources);
        //Attach the bag-of-words modifiers
        NodeList bowMods=root.getElementsByTagName("bowmodifier");
        for(int i=0;i<bowMods.getLength();i++)
        {
            Element bowMod=(Element) bowMods.item(i);
            BoWModifier mod=(BoWModifier) Class.forName(bowMod.getAttribute("class")).getConstructor().newInstance();
            // getAttribute returns "" (never null) when "config" is absent, so the modifier always receives a parameter string.
            mod.addParameters(bowMod.getAttribute("config"));
            mod.setDict(dict);
            mod.init();
            dict.addModifier(mod);
        }
        //Load the test sets
        System.out.println("Loading test sets");
        NodeList docs=root.getElementsByTagName("testset");
        ArrayList<String> testsetsnames=new ArrayList<String>(docs.getLength());
        ArrayList<String> sensefilters=new ArrayList<String>(docs.getLength());
        ArrayList<ArrayList<Input>> testsets=new ArrayList<ArrayList<Input>>(docs.getLength());
        ArrayList<String> prefixes=new ArrayList<String>(docs.getLength());
        // Every Input of every test set; handed to the skip filters below.
        ArrayList<Input> ins=new ArrayList<Input>();
        for(int i=0;i<docs.getLength();i++)
        {
            Element doc=(Element)docs.item(i);
            // Always recorded (possibly ""), so sensefilters stays index-aligned with the other per-test-set lists.
            String sfilter=doc.getAttribute("senses");
            sensefilters.add(sfilter);
            SenseFilter filter=TestSet.createSenseFilter(sfilter);
            // parseBoolean("") is false, so a missing attribute means "do not include untagged words".
            boolean noTag=Boolean.parseBoolean(doc.getAttribute("includeNoTags"));
            prefixes.add(doc.getAttribute("output"));
            testsetsnames.add(doc.getAttribute("path"));
            //Load the tests
            ArrayList<File> files=Util.getAllSGFFiles(new File(doc.getAttribute("path")));
            ArrayList<Input> inputs=new ArrayList<Input>(files.size());
            int x=1;
            for(File file:files)
            {
                System.out.println("Loading "+file.getName()+" "+String.valueOf(x)+"/"+String.valueOf(files.size()));
                x++;
                Input in=TestSet.loadInput(file,filter,sfilter,dict,noTag);
                inputs.add(in);
                ins.add(in);
            }
            testsets.add(inputs);
        }
        //Load the xls data.
        Element xls=TestSet.requireElement(root,"output");
        String src=xls.getAttribute("path");
        File output=new File(src);
        ArrayList<String> rscripts=new ArrayList<String>();
        ArrayList<String> filenames=new ArrayList<String>();
        if(!output.exists())
        {
            output.mkdirs();
        }
        //Cycle through all the algorithm nodes
        NodeList algorithms=root.getElementsByTagName("algorithm");
        ArrayList<ArrayList<Test>> tests=new ArrayList<ArrayList<Test>>(testsets.size());
        int tindex=0;
        for(ArrayList<Input> inputs:testsets)
        {
            ArrayList<Test> ts=new ArrayList<Test>();
            for(int a=0;a<algorithms.getLength();a++)
            {
                Element algo=(Element)algorithms.item(a);
                Element wsdNode=TestSet.requireElement(algo,"wsd");
                Element backoffNode=TestSet.requireElement(algo,"backoff");
                Element tieNode=TestSet.requireElement(algo,"tie");
                //create all the possible algorithm/backoff/tie configurations.
                // NOTE(review): if <wsd> has no <param> children, wsdconfig is empty and no Test is created for
                // this algorithm, unlike backoff/tie, which fall back to a single configuration-less run.
                ArrayList<String> wsdconfig=TestSet.getConfigurationStrings(wsdNode.getElementsByTagName("param"));
                ArrayList<String> backoffconfig=TestSet.getConfigurationStrings(backoffNode.getElementsByTagName("param"));
                ArrayList<String> tieconfig=TestSet.getConfigurationStrings(tieNode.getElementsByTagName("param"));
                //Retrieving the window and skip filters (fresh instances for every test set)
                ArrayList<WindowFilter> filters=TestSet.instantiateFilters(wsdNode, dict);
                ArrayList<WindowFilter> bfilters=TestSet.instantiateFilters(backoffNode, dict);
                ArrayList<WindowFilter> tfilters=TestSet.instantiateFilters(tieNode, dict);
                ArrayList<SkipFilter> sfilters=TestSet.instantiateSkipFilters(wsdNode, ins,dict);
                ArrayList<SkipFilter> bsfilters=TestSet.instantiateSkipFilters(backoffNode,ins, dict);
                ArrayList<SkipFilter> tsfilters=TestSet.instantiateSkipFilters(tieNode, ins,dict);
                ArrayList<ArrayList<WindowFilter>> combos=generateCombos(wsdNode,filters);
                // An empty list of variants still yields one run (without that kind of configuration).
                int csize=Math.max(1,combos.size());
                int bsize=Math.max(1,backoffconfig.size());
                int tsize=Math.max(1,tieconfig.size());
                for(String wsdconf:wsdconfig)
                    for(int c=0;c<csize;c++)
                        for(int b=0;b<bsize;b++)
                            for(int t=0;t<tsize;t++)
                            {
                                WSDAlgorithm wsd=TestSet.instantiateAlgorithm(wsdNode,dict);
                                WSDAlgorithm backoff=TestSet.instantiateAlgorithm(backoffNode,dict);
                                WSDAlgorithm tie=TestSet.instantiateAlgorithm(tieNode,dict);
                                if(wsd==null)
                                    throw new Exception("You must specify a valid wsd class!");
                                Test testx=new Test(inputs,wsd,backoff,tie,testsetsnames.get(tindex),sensefilters.get(tindex),sources,dict);
                                ts.add(testx);
                                ArrayList<KeyString> configs=testx.getConfigurations();
                                configs.add(new KeyString("wsd",wsdconf));
                                wsd.setSkipFilters(sfilters);
                                if(backoff!=null)
                                    backoff.setSkipFilters(bsfilters);
                                if(tie!=null)
                                    tie.setSkipFilters(tsfilters);
                                if(combos.size()>0)
                                    wsd.setWindowFilters(combos.get(c));
                                else
                                    wsd.setWindowFilters(new ArrayList<WindowFilter>(1));
                                if(backoff!=null)
                                    backoff.setWindowFilters(bfilters);
                                if(tie!=null)
                                    tie.setWindowFilters(tfilters);
                                if(b<backoffconfig.size())
                                    configs.add(new KeyString("backoff",backoffconfig.get(b)));
                                if(t<tieconfig.size())
                                    configs.add(new KeyString("tie",tieconfig.get(t)));
                            }
            }
            tests.add(ts);
            tindex++;
        }
        int j=0;
        int r=0;
        BufferedWriter out=new BufferedWriter(new FileWriter("./output.txt"));
        // The writer must be released even when a test fails; closing it also closes the wrapped FileWriter.
        try
        {
            // First pass: create the output folders and compute every result/script file name.
            int xx=0;
            for(String prefix:prefixes)
            {
                File file=new File(src+"/"+prefix);
                file.mkdirs();
                rscripts.add(src+"/"+prefix+".R");
                for(int i=0;i<tests.get(xx).size();i++)
                {
                    filenames.add(src+"/"+prefix+"/Test"+String.valueOf(i+1)+"_"+tests.get(xx).get(i).getAlgorithm().getName()+".xls");
                }
                xx++;
            }
            // Second pass: run every test whose result file is not already present.
            for(ArrayList<Test> testX:tests)
            {
                for(Test test:testX)
                {
                    File ex=new File(filenames.get(j));
                    if(!ex.exists())
                        test.run(filenames.get(j),xls.getAttribute("summary"),xls.getAttribute("detail"),out);
                    j++;
                }
                if(Boolean.parseBoolean(xls.getAttribute("stat")))
                    XLSWriter.generateRScript(new File(rscripts.get(r)), testX);
                // Drop the finished tests so their data can be garbage collected before the next test set.
                testX.clear();
                r++;
                out.write("====================================\n\n\n");
            }
        }
        finally
        {
            out.close();
        }
    }
    /**
     * Returns the first descendant of {@code parent} with the given tag name.
     * @param parent Node to search in.
     * @param tag Tag name of the required node.
     * @return The first matching element.
     * @throws IllegalArgumentException if no such node exists.
     */
    private static Element requireElement(Element parent,String tag)
    {
        Element element=(Element)parent.getElementsByTagName(tag).item(0);
        if(element==null)
            throw new IllegalArgumentException("The configuration is missing a <"+tag+"> node inside <"+parent.getTagName()+">");
        return element;
    }
    /**
     * Creates the SenseFilter encoded in a test set's "senses" attribute.
     * @param spec "+N" (FirstSenses), "*N" (NthSenseOnly), "-N" (RemoveNthSense); anything else means no filter.
     * @return The corresponding SenseFilter, or null when {@code spec} has none of the known prefixes.
     * @throws Exception if the filter cannot be created.
     */
    private static SenseFilter createSenseFilter(String spec)throws Exception
    {
        if(spec.startsWith("+"))
            return new FirstSenses("N:"+spec.substring(1));
        if(spec.startsWith("*"))
            return new NthSenseOnly("N:"+spec.substring(1));
        if(spec.startsWith("-"))
            return new RemoveNthSense("N:"+spec.substring(1));
        return null;
    }
    /**
     * Loads an Input from the cache under ./data/inputs/, or builds it and stores it in that cache.
     * <p>
     * The cache file name includes the sense filter, because the stored Input is built with it.
     * Unfiltered inputs keep the historical name {@code <file>_<noTag>}.
     * NOTE(review): the key does not include the test-set directory, so equally named files from
     * different directories still share a cache entry.
     * @param file Source file of the input.
     * @param filter Sense filter applied when building the input (may be null).
     * @param senseSpec The raw "senses" attribute the filter was built from.
     * @param dict Base dictionary.
     * @param noTag Value passed to Input for including untagged words.
     * @return The loaded or newly built Input.
     * @throws Exception if loading, building or caching the input fails.
     */
    private static Input loadInput(File file,SenseFilter filter,String senseSpec,DataBroker dict,boolean noTag)throws Exception
    {
        String cacheDir="./data/inputs/"+dict.getCompleteName().replace(">", "@@@@@@")+"/";
        File cached=new File(cacheDir+file.getName()+"_"+noTag+TestSet.inputCacheSuffix(filter,senseSpec));
        if(cached.exists())
        {
            Input in=(Input)Util.loadObject(cached);
            // The dictionary reference is re-attached after deserialization.
            for(AmbiguousWord word:in.getAmbiguousWords())
            {
                word.setDict(dict);
            }
            return in;
        }
        File dir=new File(cacheDir);
        if(!dir.exists())
            dir.mkdirs();
        Input in=new Input(file,filter,dict,noTag,true);
        Util.writeObject(cached, in);
        return in;
    }
    /**
     * Builds the cache-file suffix that distinguishes inputs built with different sense filters.
     * @param filter Sense filter (may be null).
     * @param senseSpec The raw "senses" attribute; starts with '+', '*' or '-' whenever filter is non-null.
     * @return "" for no filter, otherwise "_" + filter class name + "_" + filter argument with unsafe characters replaced.
     */
    private static String inputCacheSuffix(SenseFilter filter,String senseSpec)
    {
        if(filter==null)
            return "";
        return "_"+filter.getClass().getSimpleName()+"_"+senseSpec.substring(1).replaceAll("[^A-Za-z0-9.]", "_");
    }
    /**
     * Method for instantiating WindowFilters as specified in the XML file.
     * Nodes whose class attribute is "none" are ignored.
     * @param algo XML node containing windowfilter nodes.
     * @param dict Base dictionary.
     * @return An ArrayList with the corresponding WindowFilter objects if any.
     * @throws Exception if a filter class cannot be found or instantiated through its (String) constructor.
     */
    static ArrayList<WindowFilter> instantiateFilters(Element algo,DataBroker dict)throws Exception
    {
        ArrayList<WindowFilter> filters=new ArrayList<WindowFilter>();
        NodeList conditions=algo.getElementsByTagName("windowfilter");
        for(int j=0;j<conditions.getLength();j++)
        {
            Element filternode=(Element)conditions.item(j);
            if(!filternode.getAttribute("class").equals("none"))
            {
                WindowFilter filter=(WindowFilter)Class.forName(filternode.getAttribute("class")).getConstructor(String.class).newInstance(filternode.getAttribute("config"));
                filter.setDict(dict);
                filters.add(filter);
            }
        }
        return filters;
    }
    /**
     * Method for instantiating SkipFilter objects as specified in the XML file.
     * Nodes whose class attribute is "none" are ignored.
     * @param algo XML node containing the skipfilter nodes.
     * @param ins List containing the Input that are going to be modified by these SkipFilter objects
     * (currently only referenced by the disabled code below).
     * @param dict Base dictionary.
     * @return An ArrayList with the corresponding SkipFilter objects if any.
     * @throws Exception if a filter class cannot be found, instantiated or configured.
     */
    static ArrayList<SkipFilter> instantiateSkipFilters(Element algo, ArrayList<Input> ins,DataBroker dict)throws Exception
    {
        ArrayList<SkipFilter> filters=new ArrayList<SkipFilter>();
        NodeList conditions=algo.getElementsByTagName("skipfilter");
        for(int j=0;j<conditions.getLength();j++)
        {
            Element filternode=(Element)conditions.item(j);
            if(!filternode.getAttribute("class").equals("none"))
            {
                SkipFilter filter=(SkipFilter)Class.forName(filternode.getAttribute("class")).getConstructor().newInstance();
                filter.setDict(dict);
                filter.setParameters(filternode.getAttribute("config"));
                filters.add(filter);
            }
            //TODO
            /*if(filternode.getAttribute("class").contains("SkipNotOSD")&&dict.isWeb())
            {
                SkipNotOSD.corpus=new WikiCorpus(ins,dict.getName(),((DataBroker)dict).getSource());
            }
            if(filternode.getAttribute("class").contains("SkipOSD")&&dict.isWeb())
            {
                SkipOSD.corpus=new WikiCorpus(ins,dict.getName(),((DataBroker)dict).getSource());
            }*/
        }
        return filters;
    }
    /**
     * Method for generating all the possible combinations of WindowFilter objects.
     * @param algo XML algorithm node containing the windowfilter nodes.
     * @param filters The instantiated filters used for generating the combinations.
     * @return When the filterCombination attribute is "single" or empty: a list holding the original list
     * (or an empty list when there are no filters).
     * Otherwise: every non-empty subset of filters, largest subsets first, each size in lexicographic
     * index order, followed by one empty combination.
     */
    static ArrayList<ArrayList<WindowFilter>> generateCombos(Element algo,ArrayList<WindowFilter> filters)
    {
        ArrayList<ArrayList<WindowFilter>> combos=new ArrayList<ArrayList<WindowFilter>>();
        String combi=algo.getAttribute("filterCombination");
        if(combi.equals("single")||combi.equals(""))
        {
            if(filters.size()>0)
                combos.add(filters);
            return combos;
        }
        int n=filters.size();
        for(int size=n;size>0;size--)
        {
            // Start with the lexicographically smallest combination {0, 1, ..., size-1}.
            int[] indexes=new int[size];
            for(int k=0;k<size;k++)
                indexes[k]=k;
            while(true)
            {
                ArrayList<WindowFilter> combo=new ArrayList<WindowFilter>(size);
                for(int index:indexes)
                    combo.add(filters.get(index));
                combos.add(combo);
                // Find the rightmost index that can still move right; position k can reach at most n-size+k.
                int pos=size-1;
                while(pos>=0&&indexes[pos]==n-size+pos)
                    pos--;
                if(pos<0)
                    break;
                indexes[pos]++;
                for(int k=pos+1;k<size;k++)
                    indexes[k]=indexes[k-1]+1;
            }
        }
        combos.add(new ArrayList<WindowFilter>());
        return combos;
    }
    /**
     * Creates a list containing valid parameter strings extracted from a generalized parameter specification such as:
     * "P1:[value1,value2]" and "P1:[1.0,1.1,...,1.4]" .
     * <p>
     * A "..." entry expands into the values between its neighbours. The step is the difference between the two
     * values before it, starting from the second one. Generation stops before the value after "...", which is
     * then added as written. Ranges may ascend or descend. Values are computed in exact decimal arithmetic, so
     * no floating-point drift appears in the generated strings.
     * @param nodes Nodes containing parameter configurations.
     * @return A list containing valid parameter strings in its base form (the cartesian product of all parameters). E.g.
     * "P1:[value1,value2]" is stored as a list containing ("P1:value1;", "P1:value2;"),
     * "P1:[1.0,1.1,...,1.4]" is stored as a list containing ("P1:1.0;", "P1:1.1;", "P1:1.2;", "P1:1.3;", "P1:1.4;").
     * An empty list is returned when there are no nodes.
     * @throws IllegalArgumentException if a "..." entry lacks two values before or one after it, or has a zero step.
     */
    private static ArrayList<String> getConfigurationStrings(
            NodeList nodes) {
        ArrayList<KeyArray> configs=new ArrayList<KeyArray>();
        for(int i=0;i<nodes.getLength();i++)
        {
            Element param=(Element)nodes.item(i);
            String config=param.getAttribute("config");
            if(!config.contains("["))//simple config string
            {
                // Only the text between the first and second ':' is used as the value.
                String []tokens=config.split(":");
                TestSet.addConfigValue(configs,tokens[0],tokens[1]);
            }
            else
            {
                //Process the config that contains a list
                String tokens[]=config.replace("[", "").replace("]", "").split(":");
                String head=tokens[0];
                String values[]=tokens[1].split(",");
                boolean isdouble=values[0].contains(".");
                for(int x=0;x<values.length;x++)
                {
                    if(values[x].equals("..."))
                        TestSet.expandRange(configs,head,values,x,isdouble);
                    else
                        TestSet.addConfigValue(configs,head,values[x]);
                }
            }
        }
        // Enumerate the cartesian product like an odometer: the last parameter varies fastest.
        int []indexes=new int[configs.size()];
        ArrayList<String> configStrings=new ArrayList<String>();
        while(configs.size()>0&&indexes[0]<configs.get(0).getArray().size())
        {
            //generate the string
            StringBuilder config=new StringBuilder();
            for(int i=0;i<configs.size();i++)
            {
                config.append(configs.get(i).getKey()).append(":").append(configs.get(i).getArray().get(indexes[i])).append(";");
            }
            configStrings.add(config.toString());
            //increase pointer
            int current=configs.size()-1;
            indexes[current]++;
            while(indexes[current]>=configs.get(current).getArray().size())//index overflow change the previous level
            {
                if(current<=0)
                    break;
                indexes[current]=0;
                current--;
                indexes[current]++;
            }
        }
        return configStrings;
    }
    /**
     * Adds a value for a parameter, appending to the existing KeyArray for that parameter when there is one.
     * @param configs Parameters collected so far.
     * @param key Parameter name.
     * @param value Parameter value.
     */
    private static void addConfigValue(ArrayList<KeyArray> configs,String key,String value)
    {
        KeyArray ka=new KeyArray(key,value);
        int pos=configs.indexOf(ka);
        if(pos>=0)
            configs.get(pos).getArray().add(value);
        else
            configs.add(ka);
    }
    /**
     * Expands the "..." entry at {@code values[x]} into the intermediate values of the range.
     * @param configs Parameters collected so far.
     * @param head Parameter name.
     * @param values All list entries of the parameter.
     * @param x Position of the "..." entry.
     * @param isdouble Whether generated values are formatted as doubles (e.g. "1.2") instead of plain decimals.
     * @throws IllegalArgumentException if the range is incomplete or its step is zero.
     */
    private static void expandRange(ArrayList<KeyArray> configs,String head,String[] values,int x,boolean isdouble)
    {
        if(x<2||x+1>=values.length)
            throw new IllegalArgumentException("A \"...\" range for "+head+" needs two values before it and one after it");
        BigDecimal base=new BigDecimal(values[x-2].trim());
        BigDecimal next=new BigDecimal(values[x-1].trim());
        BigDecimal end=new BigDecimal(values[x+1].trim());
        BigDecimal inc=next.subtract(base);
        int direction=inc.signum();
        if(direction==0)
            throw new IllegalArgumentException("A \"...\" range for "+head+" has a zero step");
        BigDecimal current=next.add(inc);
        // Stop strictly before the end value: values[x+1] is added by the caller as an ordinary value.
        while(current.compareTo(end)*direction<0)
        {
            String value=isdouble?String.valueOf(current.doubleValue()):current.toPlainString();
            TestSet.addConfigValue(configs,head,value);
            current=current.add(inc);
        }
    }
    /**
     * Instantiates a WSDAlgorithm as specified in an XML node.
     * @param item XML node containing the WSDAlgorithm specification.
     * @param dict Base dictionary.
     * @return The newly created WSDAlgorithm, or null when the class attribute is "none".
     * @throws Exception if the class cannot be found or instantiated through its public no-argument constructor.
     */
    private static WSDAlgorithm instantiateAlgorithm(Element item,
            DataBroker dict)throws Exception {
        WSDAlgorithm algo=null;
        if(!item.getAttribute("class").equals("none"))
        {
            algo=(WSDAlgorithm)Class.forName(item.getAttribute("class")).getConstructor().newInstance();
            algo.setDict(dict);
        }
        return algo;
    }
}