/*Copyright (c) 2006, 2007, 2008 Neil Walkinshaw and Kirill Bogdanov
This file is part of StateChum
StateChum is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
StateChum is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with StateChum. If not, see <http://www.gnu.org/licenses/>.
*/
package statechum.analysis.learning;
import java.awt.Frame;
import java.io.StringWriter;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.Stack;
import java.util.TreeMap;
import java.util.Map.Entry;
import java.util.concurrent.atomic.AtomicInteger;
import statechum.Configuration;
import statechum.JUConstants;
import statechum.Pair;
import statechum.Configuration.QuestionGeneratorKind;
import statechum.DeterministicDirectedSparseGraph.CmpVertex;
import statechum.analysis.learning.observers.Learner;
import statechum.analysis.learning.observers.Learner.RestartLearningEnum;
import statechum.analysis.learning.rpnicore.ComputeQuestions;
import statechum.analysis.learning.rpnicore.LearnerGraph;
import statechum.analysis.learning.rpnicore.MergeStates;
import statechum.analysis.learning.rpnicore.WMethod;
import statechum.model.testset.PTASequenceEngine;
import statechum.model.testset.PTASequenceSet;
import edu.uci.ics.jung.graph.impl.DirectedSparseGraph;
import edu.uci.ics.jung.utils.UserData;
/**
 * Blue-fringe learner which performs merges on a {@link LearnerGraph} ({@link #scoreComputer})
 * and, for each candidate merge, asks questions about the merged graph via {@link #topLearner}.
 * <p>
 * Every answer is added to the PTA the learner started from. If an answer contradicts the merged
 * graph, learning restarts from that augmented PTA. Otherwise the merged graph is kept and
 * learning continues with the next candidate pair.
 */
public class RPNIBlueFringeLearnerTestComponentOpt extends RPNIBlueFringeLearner {
	/**
	 * Creates the learner.
	 *
	 * @param parent frame passed to the superclass constructor
	 * @param c configuration passed to the superclass and used to build the initial graph
	 */
	public RPNIBlueFringeLearnerTestComponentOpt(Frame parent, Configuration c) {
		super(parent,c);
		scoreComputer = new LearnerGraph(c);
	}

	/**
	 * Highlights both states of the supplied pair.
	 *
	 * @param pair pair whose Q and R states are highlighted
	 */
	protected void update(StatePair pair)
	{
		pair.getQ().setHighlight(true);
		pair.getR().setHighlight(true);// since this copy of the graph will really not be used, changes to it are immaterial at this stage
	}

	/** The graph on which merges are performed. It is replaced after every merge or restart. */
	protected LearnerGraph scoreComputer = null;
	/** Counters of answers received and of candidate merges which produced no questions. */
	protected int counterAccepted =0, counterRejected =0, counterEmptyQuestions = 0;

	/**
	 * Increments the number of occurrences of {@code key} in {@code histogram}, adding an entry if necessary.
	 *
	 * @param histogram map from a key to the number of times it has been seen; modified in place
	 * @param key the key to count
	 */
	private static void incrementHistogramEntry(Map<Integer,AtomicInteger> histogram, int key)
	{
		AtomicInteger count = histogram.get(key);
		if (count == null)
		{
			count = new AtomicInteger();histogram.put(key,count);
		}
		count.incrementAndGet();
	}

	/**
	 * Takes the candidates for merging and computes the number of times each distinct score is encountered.
	 *
	 * @param data pairs whose scores are counted
	 * @param histogram map from a score to its number of occurrences; updated in place
	 */
	public static void populateScores(Collection<PairScore> data, Map<Integer,AtomicInteger> histogram)
	{
		for(PairScore pair:data)
			incrementHistogramEntry(histogram, pair.getScore());
	}

	/**
	 * Maps a score to the histogram bucket used by {@link #populateHistogram}.
	 * <ul>
	 * <li>Scores of 200 or more are rounded down to a multiple of 100.</li>
	 * <li>Scores from 10 to 199 are rounded down to a multiple of 10.</li>
	 * <li>Scores from 1 to 9 map to 1.</li>
	 * <li>Zero and negative scores map to 0.</li>
	 * </ul>
	 *
	 * @param score score to classify
	 * @return the bucket for the score
	 */
	private static int histogramBucket(int score)
	{
		if (score >= 200)
			return score - score % 100;
		if (score >= 10)
			return score - score % 10;
		return score > 0 ? 1 : 0;
	}

	/**
	 * Takes the candidates for merging and counts how many scores fall into each bucket.
	 * Buckets are 100 wide from 200 upwards and 10 wide from 10 to 199. All scores from 1 to 9
	 * share bucket 1, and all non-positive scores share bucket 0.
	 *
	 * @param data pairs whose scores are counted
	 * @param histogram map from a bucket to its number of occurrences; updated in place
	 */
	public static void populateHistogram(Collection<PairScore> data, Map<Integer,AtomicInteger> histogram)
	{
		for(PairScore pair:data)
			incrementHistogramEntry(histogram, histogramBucket(pair.getScore()));
	}

	/**
	 * Renders a histogram as two comma-separated rows, each starting with {@code Name}.
	 * The first row holds the counts and the second the corresponding keys, both in ascending key order.
	 *
	 * @param histogram map from a key to its count
	 * @param Name label placed at the start of each row
	 * @return the two rows, preceded and followed by a newline
	 */
	public static String HistogramToString(Map<Integer,AtomicInteger> histogram, String Name)
	{
		final String FS=",";
		Map<Integer, AtomicInteger> sorted = new TreeMap<Integer,AtomicInteger>();
		sorted.putAll(histogram);
		StringBuilder result = new StringBuilder();
		result.append("\n").append(Name);
		for(Entry<Integer,AtomicInteger> sc:sorted.entrySet())
			result.append(FS).append(sc.getValue());
		result.append("\n").append(Name);
		for(Entry<Integer,AtomicInteger> sc:sorted.entrySet())
			result.append(FS).append(sc.getKey());
		return result.append("\n").toString();
	}

	/**
	 * Expands a histogram into a single comma-separated row starting with {@code Name}.
	 * Each key is repeated as many times as its count, in ascending key order.
	 *
	 * @param histogram map from a key to its count
	 * @param Name label placed at the start of the row
	 * @return the row, preceded and followed by a newline
	 */
	public static String HistogramToSeries(Map<Integer,AtomicInteger> histogram, String Name)
	{
		final String FS=",";
		Map<Integer, AtomicInteger> sorted = new TreeMap<Integer,AtomicInteger>();
		sorted.putAll(histogram);
		StringBuilder result = new StringBuilder();
		result.append("\n").append(Name);
		for(Entry<Integer,AtomicInteger> sc:sorted.entrySet()){
			int limit = sc.getValue().get();
			for(int i = 0;i<limit;i++)
				result.append(FS).append(sc.getKey());
		}
		return result.append("\n").toString();
	}

	/**
	 * Renders a map from pairs to iteration numbers as two comma-separated rows.
	 * The first row, labelled {@code name-score}, lists the score of each pair. The second row,
	 * labelled {@code name-iteration}, lists the corresponding iteration numbers. Every value is
	 * followed by a separator, and there is no trailing newline. Both rows come from a single pass
	 * over the entries, so column {@code i} of each row always refers to the same pair.
	 *
	 * @param map map from a pair to the iteration at which it was considered
	 * @param name prefix for the row labels
	 * @return the rendered rows
	 */
	public static String pairScoresAndIterations(Map<PairScore,Integer> map, String name){
		final String FS=",";
		StringBuilder scores = new StringBuilder();
		StringBuilder iterationNumbers = new StringBuilder();
		scores.append("\n").append(name).append("-score").append(FS);
		iterationNumbers.append("\n").append(name).append("-iteration").append(FS);
		for(Entry<PairScore,Integer> entry:map.entrySet())
		{
			scores.append(entry.getKey().getScore()).append(FS);
			iterationNumbers.append(entry.getValue()).append(FS);
		}
		return scores.append(iterationNumbers).toString();
	}

	/** The size of the initial plus/minus sets. */
	protected int origPlusSize, origMinusSize;

	/**
	 * Builds the initial PTA from positive and negative sequences and records the size of each set.
	 *
	 * @param plus sequences to add as accepted
	 * @param minus sequences to add as rejected
	 */
	@Override
	public void init(Collection<List<String>> plus, Collection<List<String>> minus)
	{
		scoreComputer.initPTA();
		scoreComputer.paths.augmentPTA(minus, false);
		scoreComputer.paths.augmentPTA(plus, true);
		origPlusSize = plus.size();origMinusSize = minus.size();
	}

	/**
	 * Builds the initial PTA from the supplied sequence engine and records the given set sizes.
	 *
	 * @param en engine whose sequences are added to the PTA
	 * @param plusSize number of positive sequences to record
	 * @param minusSize number of negative sequences to record
	 */
	@Override
	public void init(PTASequenceEngine en, int plusSize, int minusSize)
	{
		scoreComputer.initPTA();
		scoreComputer.paths.augmentPTA(en);
		origPlusSize = plusSize;origMinusSize = minusSize;
	}

	/**
	 * Replaces the current graph with one loaded from {@code name}.
	 * NOTE(review): this uses the default configuration rather than {@code config}; confirm this is intended.
	 *
	 * @param name name of the graph to load
	 */
	@Override
	public void loadPTA(String name)
	{
		scoreComputer = LearnerGraph.loadGraph(name, Configuration.getDefaultConfiguration());
	}

	/**
	 * Describes how two collections of sequences differ.
	 * It lists the elements of {@code seqNew} absent from {@code seqOrig}, then the elements of
	 * {@code seqOrig} absent from {@code seqNew}.
	 *
	 * @param prefix label placed before each part
	 * @param seqOrig original sequences
	 * @param seqNew new sequences
	 * @return textual description of both differences
	 */
	public String DifferenceBetweenPairOfSets(String prefix, Collection<List<String>> seqOrig,Collection<List<String>> seqNew)
	{
		Set<List<String>> newInQS = new HashSet<List<String>>();newInQS.addAll(seqNew);newInQS.removeAll(seqOrig);
		Set<List<String>> newInOrig = new HashSet<List<String>>();newInOrig.addAll(seqOrig);newInOrig.removeAll(seqNew);
		return prefix+": new in QS:\n"+newInQS+"\n"+prefix+": new In Orig:\n"+newInOrig;
	}

	/**
	 * Shows the supplied graph, but only when debug mode is enabled in the configuration.
	 *
	 * @param lg graph to display
	 * @param iterations current iteration number (currently unused)
	 */
	protected void debugAction(LearnerGraph lg, @SuppressWarnings("unused") int iterations){
		if(!config.getDebugMode())
			return;
		updateGraph(lg);
	}

	/** @return the {@link Learner} which delegates to this instance. */
	public Learner getLearner()
	{
		return thisLearner;
	}

	/** Implementation of {@link Learner} that forwards each step to this learner or to the library routines it uses. */
	protected final Learner thisLearner = new Learner()
	{
		public void AugmentPTA(LearnerGraph pta, @SuppressWarnings("unused") RestartLearningEnum ptaKind,
				List<String> sequence, boolean accepted, JUConstants newColour) {
			pta.paths.augmentPTA(sequence, accepted, newColour);
		}

		public Pair<Integer, String> CheckWithEndUser(LearnerGraph graph,
				List<String> question, Object[] options) {
			return RPNIBlueFringeLearnerTestComponentOpt.this.checkWithEndUser(graph, question, options);
		}

		public Stack<PairScore> ChooseStatePairs(LearnerGraph graph) {
			return graph.pairscores.chooseStatePairs();
		}

		public List<List<String>> ComputeQuestions(PairScore pair,
				LearnerGraph original, LearnerGraph temp) {
			return ComputeQuestions.computeQS(pair, original,temp);
		}

		public LearnerGraph MergeAndDeterminize(LearnerGraph original, StatePair pair) {
			return MergeStates.mergeAndDeterminize_general(original, pair);
		}

		public void Restart(@SuppressWarnings("unused") RestartLearningEnum mode) {
			// does nothing
		}

		public String getResult() {
			return null;
		}

		public LearnerGraph init(Collection<List<String>> plus, Collection<List<String>> minus)
		{
			RPNIBlueFringeLearnerTestComponentOpt.this.init(plus, minus);
			return RPNIBlueFringeLearnerTestComponentOpt.this.scoreComputer;
		}

		public LearnerGraph init(PTASequenceEngine engine, int plusSize, int minusSize) {
			RPNIBlueFringeLearnerTestComponentOpt.this.init(engine, plusSize, minusSize);
			return RPNIBlueFringeLearnerTestComponentOpt.this.scoreComputer;
		}

		public LearnerGraph learnMachine() {
			// NOTE(review): the overloads below use scoreComputer.config, whereas this one uses the default configuration.
			return new LearnerGraph(RPNIBlueFringeLearnerTestComponentOpt.this.learnMachine(),Configuration.getDefaultConfiguration());
		}

		public LearnerGraph learnMachine(PTASequenceEngine engine, int plusSize, int minusSize)
		{
			topLearner.init(engine, plusSize, minusSize);
			return new LearnerGraph(RPNIBlueFringeLearnerTestComponentOpt.this.learnMachine(),scoreComputer.config);
		}

		public LearnerGraph learnMachine(Collection<List<String>> plus, Collection<List<String>> minus)
		{
			topLearner.init(plus, minus);
			return new LearnerGraph(RPNIBlueFringeLearnerTestComponentOpt.this.learnMachine(),scoreComputer.config);
		}

		public void setTopLevelListener(Learner top) {
			topLearner = top;
		}
	};

	/** Learner through which every learning step is dispatched; defaults to {@link #thisLearner}. */
	Learner topLearner = thisLearner;

	/* Note: in order to get the same results from learning as in modified Dec 2007 version
	 * on the appropriate branch, the following has to be done:
	 * 1. DeterministicDirectedSparseGraph.VertexID.comparisonKind = DeterministicDirectedSparseGraph.VertexID.ComparisonKind.COMPARISON_LEXICOGRAPHIC_ORIG;
	 * 2. load the initial PTA from _mt files (dumped by Dec 2007 version).
	 * 3. merge using tempOrig = MergeStates.mergeAndDeterminize(scoreComputer, pair);
	 * 4. generate questions using questions = ArrayOperations.sort(ComputeQuestions.computeQS_origReduced(pair, scoreComputer,tempOrig));
	 */
	/**
	 * Runs the learning loop.
	 * <ol>
	 * <li>Pop the next candidate pair, merge it and, where {@code shouldAskQuestions} allows, compute questions.</li>
	 * <li>Ask each question and add the answer to the initial PTA.</li>
	 * <li>Restart from that PTA when an answer contradicts the merged graph; otherwise keep the merged graph.</li>
	 * </ol>
	 *
	 * @return the learnt graph, with the report attached as {@link JUConstants#STATS};
	 *  {@code null} if the user cancelled
	 */
	@Override
	public DirectedSparseGraph learnMachine() {
		setAutoOracle();
		// Statistics: scores of pairs which were merged / which caused a restart, and the iteration at which each was considered.
		Map<Integer, AtomicInteger> whichScoresWereUsedForMerging = new HashMap<Integer,AtomicInteger>(),
			restartScoreDistribution = new HashMap<Integer,AtomicInteger>();
		Map<PairScore, Integer> scoresToIterations = new HashMap<PairScore, Integer>();
		Map<PairScore, Integer> restartsToIterations = new HashMap<PairScore, Integer>();
		// newPTA keeps referring to the initial PTA; answers are added to it so learning can restart from it.
		LearnerGraph newPTA = scoreComputer;// no need to clone - this is the job of mergeAndDeterminize anyway
		StringWriter report = new StringWriter();
		counterAccepted =0;counterRejected =0;counterRestarted = 0;counterEmptyQuestions = 0;
		report.write("\n[ PTA: "+scoreComputer.paths.getStatistics(false)+" ] ");
		setChanged();
		newPTA.setName("merge_debug"+0);
		updateGraph(newPTA);
		Stack<PairScore> possibleMerges = topLearner.ChooseStatePairs(scoreComputer);
		int plusSize = origPlusSize, minusSize = origMinusSize, iterations = 0;
		// Diagnostics: traces merges and answers during the restart with this number.
		// counterRestarted starts at 0 above and only increases, so a negative value disables tracing.
		final int restartOfInterest = -21;
		while(!possibleMerges.isEmpty())
		{
			iterations++;
			PairScore pair = possibleMerges.pop();
			if (counterRestarted == restartOfInterest) System.out.println("merging "+pair);
			LearnerGraph tempOrig= null;
			LearnerGraph tempNew = null;
			tempNew = topLearner.MergeAndDeterminize(scoreComputer, pair);
			LearnerGraph temp=tempNew;
			if (scoreComputer.config.isConsistencyCheckMode())
			{// cross-check the merge against MergeStates.mergeAndDeterminize
				tempOrig = MergeStates.mergeAndDeterminize(scoreComputer, pair);
				WMethod.checkM(tempNew, tempOrig);
				MergeStates.verifySameMergeResults(tempOrig, tempNew);
			}
			setChanged();temp.setName("merge_debug"+iterations);
			debugAction(temp, iterations);
			Collection<List<String>> questions = new LinkedList<List<String>>();
			int score = pair.getScore();
			if(shouldAskQuestions(score))
			{
				questions = topLearner.ComputeQuestions(pair, scoreComputer, tempNew);
				if (scoreComputer.config.isConsistencyCheckMode())
				{// checking that all the old questions are included in the new ones
					assert scoreComputer.config.getQuestionGenerator() == QuestionGeneratorKind.CONVENTIONAL;
					assert scoreComputer.config.getQuestionPathUnionLimit() < 0;
					Collection<List<String>> questionsOrigA = ComputeQuestions.computeQS_orig(pair, scoreComputer,tempOrig);
					CmpVertex Rnew = tempNew.getStateLearnt();
					assert Rnew == tempNew.getVertex(scoreComputer.wmethod.computeShortPathsToAllStates().get(pair.getR()));
					Collection<List<String>> questionsOrigB = ComputeQuestions.computeQS_orig(new StatePair(Rnew,Rnew), scoreComputer,tempNew);
					PTASequenceSet newQuestions =new PTASequenceSet();newQuestions.addAll(questions);
					assert newQuestions.containsAll(questionsOrigA);
					assert newQuestions.containsAll(questionsOrigB);
				}
				if (questions.isEmpty())
					++counterEmptyQuestions;
			}
			boolean restartLearning = false;// whether we need to rebuild a PTA and restart learning.
			Iterator<List<String>> questionIt = questions.iterator();
			while(questionIt.hasNext()){
				List<String> question = questionIt.next();
				boolean accepted = pair.getQ().isAccept();
				Pair<Integer,String> answer = topLearner.CheckWithEndUser(scoreComputer, question, null);
				this.questionCounter++;
				if (answer.firstElem == USER_CANCELLED)
				{
					System.out.println("CANCELLED");
					return null;
				}
				// State reached by the question in the merged graph, compared with the answer below.
				CmpVertex tempVertex = temp.getVertex(question);
				if(answer.firstElem == USER_ACCEPTED)
				{
					++counterAccepted;
					topLearner.AugmentPTA(newPTA, RestartLearningEnum.restartHARD, question, true, null);
					++plusSize;
					if (ans != null) System.out.println(howAnswerWasObtained+question.toString()+ " <yes>");
					if (counterRestarted == restartOfInterest) System.out.println(question.toString()+ " <yes>");
					if(!tempVertex.isAccept())
					{// the merged graph rejects a sequence the user accepted
						restartLearning = true;break;
					}
				}
				else
					if(answer.firstElem >= 0)
					{// The sequence has been rejected by a user; firstElem is the position of the first rejected element.
						assert answer.firstElem < question.size();
						++counterRejected;
						LinkedList<String> subAnswer = new LinkedList<String>();subAnswer.addAll(question.subList(0, answer.firstElem+1));
						topLearner.AugmentPTA(newPTA, RestartLearningEnum.restartHARD, subAnswer, false, null);
						++minusSize ;// important: since vertex IDs are
						// only unique for each instance of ComputeStateScores, only one
						// instance should ever receive calls to augmentPTA
						if (ans != null) System.out.println(howAnswerWasObtained+question.toString()+ " <no> at position "+answer.firstElem+", element "+question.get(answer.firstElem));
						if (counterRestarted == restartOfInterest) System.out.println(question.toString()+ " <no> at position "+answer.firstElem+", element "+question.get(answer.firstElem));
						if( (answer.firstElem < question.size()-1) || tempVertex.isAccept())
						{// rejected before the end of the question, or the merged graph accepts it
							assert accepted == true;
							restartLearning = true;break;
						}
					}
					else
						throw new IllegalArgumentException("unexpected user choice");
			}
			if (restartLearning)
			{// restart learning from the PTA augmented with the answers obtained so far
				scoreComputer = newPTA;// no need to clone - this is the job of mergeAndDeterminize anyway
				scoreComputer.clearColours();
				++counterRestarted;
				incrementHistogramEntry(restartScoreDistribution, pair.getScore());
				restartsToIterations.put(pair, iterations);
				iterations = 0;
				topLearner.Restart(RestartLearningEnum.restartHARD);
			}
			else
			{
				// At this point, scoreComputer may have been modified because it may point to
				// the original PTA which will be modified as a result of new sequences being added to it.
				// temp is different too, hence there is no way for me to compute compatibility score here.
				// This is hence computed inside the obtainPair method.
				// keep going with the existing model
				scoreComputer = temp;
				// now update the statistics
				incrementHistogramEntry(whichScoresWereUsedForMerging, pair.getScore());
				scoresToIterations.put(pair, iterations);
				topLearner.Restart(RestartLearningEnum.restartNONE);
			}
			possibleMerges = topLearner.ChooseStatePairs(scoreComputer);
		}
		DirectedSparseGraph result = scoreComputer.paths.getGraph();result.addUserDatum(JUConstants.STATS, report.toString(), UserData.SHARED);
		if(config.getDebugMode())
			updateGraph(scoreComputer);
		return result;
	}
}