Package statechum.analysis.learning

Source Code of statechum.analysis.learning.RPNIBlueAmberFringeLearner

package statechum.analysis.learning;

import java.awt.Frame;
import java.util.Collection;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
import java.util.Stack;

import statechum.Configuration;
import statechum.JUConstants;
import statechum.Pair;
import statechum.Configuration.QuestionGeneratorKind;
import statechum.DeterministicDirectedSparseGraph.CmpVertex;
import statechum.analysis.learning.observers.Learner;
import statechum.analysis.learning.observers.Learner.RestartLearningEnum;
import statechum.analysis.learning.rpnicore.ComputeQuestions;
import statechum.analysis.learning.rpnicore.LearnerGraph;
import statechum.analysis.learning.rpnicore.MergeStates;
import statechum.model.testset.PTASequenceEngine;
import statechum.model.testset.PTASequenceSet;
import edu.uci.ics.jung.graph.impl.DirectedSparseGraph;

public class RPNIBlueAmberFringeLearner extends RPNIBlueFringeLearner {
 
  public RPNIBlueAmberFringeLearner(Frame parent, Configuration c) {
    super(parent,c);
    scoreComputer = new LearnerGraph(c);
  }
 
  protected LearnerGraph scoreComputer = null;

 
  /** The size of the initial plus/minus sets. */
  protected int origPlusSize, origMinusSize;
 
  @Override
  public void init(Collection<List<String>> plus, Collection<List<String>> minus)
  {
    scoreComputer.initPTA();
    scoreComputer.paths.augmentPTA(minus, false);
    scoreComputer.paths.augmentPTA(plus, true);
    origMinusSize = plus.size();origMinusSize = minus.size();
  }
 
  @Override
  public void init(PTASequenceEngine en, int plus, int minus)
  {
    scoreComputer.initPTA();
    scoreComputer.paths.augmentPTA(en);

    origMinusSize = plus;origMinusSize = minus;
  }

  public String DifferenceBetweenPairOfSets(String prefix, Collection<List<String>> seqOrig,Collection<List<String>> seqNew)
  {
    Set<List<String>> newInQS = new HashSet<List<String>>();newInQS.addAll(seqNew);newInQS.removeAll(seqOrig);
    Set<List<String>> newInOrig = new HashSet<List<String>>();newInOrig.addAll(seqOrig);newInOrig.removeAll(seqNew);
    return prefix+": new in QS:\n"+newInQS+"\n"+prefix+": new In Orig:\n"+newInOrig;
  }
 
  int plusSize = 0, minusSize = 0;
 
  /** Does
   * <pre>
   * MergeStates.mergeAndDeterminize_general(scoreComputer, pair);
   * </pre>
   * but additionally checks for consistency.
   */
  protected LearnerGraph getMergedGraph(StatePair pair, LearnerGraph newPTA)
  {
    int nonAmberA = 0;
    if (scoreComputer.config.isConsistencyCheckMode()) nonAmberA = newPTA.getStateNumber()-newPTA.getAmberStateNumber();
    LearnerGraph tempNew = MergeStates.mergeAndDeterminize_general(scoreComputer, pair);
    if (scoreComputer.config.isConsistencyCheckMode()) assert (newPTA.getStateNumber()-newPTA.getAmberStateNumber()) == nonAmberA;

    if (scoreComputer.config.isConsistencyCheckMode())
    {
      LearnerGraph tempOrig = MergeStates.mergeAndDeterminize(scoreComputer, pair);
      MergeStates.verifySameMergeResults(tempOrig, tempNew);
    }
    return tempNew;
  }
 
  /** Returns a collection of questions, but also checks them for consistency.
   *
   * @param tempNew the graph after merge.
   * @param pair pair of states to consider.
   * @return questions to ask.
   */
  protected List<List<String>> getQuestions(LearnerGraph tempNew, PairScore pair)
  {
    List<List<String>> questions = ComputeQuestions.computeQS(pair, scoreComputer,tempNew);
    if (scoreComputer.config.isConsistencyCheckMode())
    {// checking that all the old questions are included in the new ones
      assert scoreComputer.config.getQuestionGenerator() == QuestionGeneratorKind.CONVENTIONAL;
      assert scoreComputer.config.getQuestionPathUnionLimit() < 0;
     
      Collection<List<String>> questionsOrigA = ComputeQuestions.computeQS_orig(pair, scoreComputer,MergeStates.mergeAndDeterminize(scoreComputer, pair));
      CmpVertex Rnew = tempNew.getStateLearnt();
      assert Rnew == tempNew.getVertex(scoreComputer.wmethod.computeShortPathsToAllStates().get(pair.getR()));
      Collection<List<String>> questionsOrigB = ComputeQuestions.computeQS_orig(new StatePair(Rnew,Rnew), scoreComputer,tempNew);
      PTASequenceSet newQuestions =new PTASequenceSet();newQuestions.addAll(questions);
      assert newQuestions.containsAll(questionsOrigA);
      assert newQuestions.containsAll(questionsOrigB);
    }
   
    return questions;
  }

  @Override
  public Learner getLearner()
  {
    return thisLearner;
  }
 
  protected final Learner thisLearner = new Learner()
  {

    public void AugmentPTA(LearnerGraph pta,
        @SuppressWarnings("unused") RestartLearningEnum ptaKind,
        List<String> sequence, boolean accepted, JUConstants newColour)
    {
      pta.paths.augmentPTA(sequence, accepted, newColour);
    }

    public Pair<Integer, String> CheckWithEndUser(LearnerGraph graph,
        List<String> question, Object[] options) {
      return RPNIBlueAmberFringeLearner.this.checkWithEndUser(graph, question, options);
    }

    public Stack<PairScore> ChooseStatePairs(LearnerGraph graph) {
      return graph.pairscores.chooseStatePairs();
    }

    public List<List<String>> ComputeQuestions(PairScore pair,
        @SuppressWarnings("unused"LearnerGraph original, LearnerGraph temp)
    {
      return getQuestions(temp, pair);
    }

    public LearnerGraph MergeAndDeterminize(LearnerGraph original, StatePair pair)
    {
      return getMergedGraph(pair,original);
    }

    public void Restart(@SuppressWarnings("unused"RestartLearningEnum mode)
    {
      // does nothing
    }

    public String getResult() {
      return null;
    }

    public LearnerGraph init(Collection<List<String>> plus,  Collection<List<String>> minus)
    {
      RPNIBlueAmberFringeLearner.this.init(plus, minus);
      return RPNIBlueAmberFringeLearner.this.scoreComputer;
    }

    public LearnerGraph init(PTASequenceEngine engine, int plus, int minus) {
      RPNIBlueAmberFringeLearner.this.init(engine, plus, minus);
      return RPNIBlueAmberFringeLearner.this.scoreComputer;
    }

    public LearnerGraph learnMachine() {
      return new LearnerGraph(RPNIBlueAmberFringeLearner.this.learnMachine(),Configuration.getDefaultConfiguration());
    }

    public LearnerGraph learnMachine(PTASequenceEngine engine, int plus, int minus)
    {
      topLearner.init(engine, plus, minus);
      return new LearnerGraph(RPNIBlueAmberFringeLearner.this.learnMachine(),scoreComputer.config);
    }

    public LearnerGraph learnMachine(Collection<List<String>> plus, Collection<List<String>> minus)
    {
      topLearner.init(plus, minus);
      return new LearnerGraph(RPNIBlueAmberFringeLearner.this.learnMachine(),scoreComputer.config);
    }

   
    public void setTopLevelListener(Learner top) {
      topLearner = top;
    }
   
  };
 
  Learner topLearner = thisLearner;
 
  @Override
  public DirectedSparseGraph learnMachine() {
    setAutoOracle();
    LearnerGraph initialPTA = scoreComputer;// no need to clone - this is the job of mergeAndDeterminize anyway
    setChanged();
    initialPTA.setName("merge_debug"+0);
    updateGraph(initialPTA);
   
    Stack<PairScore> possibleMerges = topLearner.ChooseStatePairs(scoreComputer);
    plusSize = origPlusSize;minusSize = origMinusSize;
    int iterations = 0, currentNonAmber = initialPTA.getStateNumber()-initialPTA.getAmberStateNumber();
    counterRestarted = 0;
    while(!possibleMerges.isEmpty())
    {
      iterations++;
      //populateScores(possibleMerges,possibleMergeScoreDistribution);
      PairScore pair = possibleMerges.pop();
      LearnerGraph temp = topLearner.MergeAndDeterminize(initialPTA, pair);// FIXME: should be scoreComputer
      //System.out.println("considering "+pair+" non-amber: "+(newPTA.getStateNumber()-newPTA.getAmberStateNumber()));
      //Visualiser.updateFrame(scoreComputer.paths.getGraph(), temp.paths.getGraph());Visualiser.waitForKey();
      setChanged();temp.setName("merge_debug"+iterations);
      updateGraph(temp);
      Collection<List<String>> questions = new LinkedList<List<String>>();
      int score = pair.getScore();

      if(shouldAskQuestions(score))
        questions = topLearner.ComputeQuestions(pair, scoreComputer, temp);

      boolean restartLearning = false;// whether we need to rebuild a PTA and restart learning.
     
      //System.out.println(Thread.currentThread()+ " "+pair + " "+questions);
      Iterator<List<String>> questionIt = questions.iterator();
      while(questionIt.hasNext()){
        List<String> question = questionIt.next();
        boolean accepted = pair.getQ().isAccept();
        Pair<Integer,String> answer = topLearner.CheckWithEndUser(scoreComputer,question, new Object [] {"Test"});
        this.questionCounter++;
        if (answer.firstElem == USER_CANCELLED)
        {
          System.out.println("CANCELLED");
          return null;
        }
       
        CmpVertex tempVertex = temp.getVertex(question);
       
        if(answer.firstElem == USER_ACCEPTED)
        {
          //sPlus.add(question);
          topLearner.AugmentPTA(initialPTA, RestartLearningEnum.restartHARD, question, true,JUConstants.AMBER);
          //initialPTA.paths.augmentPTA(question, true,JUConstants.AMBER);
          ++plusSize;
          if (ans != null) System.out.println(howAnswerWasObtained+question.toString()+ " <yes>");
          if(!tempVertex.isAccept())
          {
            restartLearning = true;break;
          }
        }
        else
          if(answer.firstElem >= 0)
          {// The sequence has been rejected by a user
            assert answer.firstElem < question.size();
            LinkedList<String> subAnswer = new LinkedList<String>();subAnswer.addAll(question.subList(0, answer.firstElem+1));
            //sMinus.add(subAnswer);
            topLearner.AugmentPTA(initialPTA, RestartLearningEnum.restartHARD, subAnswer, false,JUConstants.AMBER);
            //initialPTA.paths.augmentPTA(subAnswer, false,JUConstants.AMBER);
            ++minusSize;
            // important: since vertex IDs are
            // only unique for each instance of ComputeStateScores, only once
            // instance should ever receive calls to augmentPTA
            if (ans != null) System.out.println(howAnswerWasObtained+question.toString()+ " <no> at position "+answer.firstElem+", element "+question.get(answer.firstElem));
            if( (answer.firstElem < question.size()-1) || tempVertex.isAccept())
            {
              assert accepted == true;
              restartLearning = true;break;
            }
          }
          else
            throw new IllegalArgumentException("unexpected user choice");
       
      }
     
      if (restartLearning)
      {// restart learning
        //ComputeStateScores expected = createAugmentedPTA(sPlus, sMinus);// KIRR: node labelling is done by createAugmentedPTA
        //System.out.println("restart at pair "+pair+", currently "+scoreComputer.getStateNumber()+" states, "+(scoreComputer.getStateNumber()-scoreComputer.getAmberStateNumber())+" non-amber");
        if (config.isSpeculativeQuestionAsking())
          if (speculativeGraphUpdate(possibleMerges, initialPTA))
            return null;
        scoreComputer = initialPTA;// no need to clone - this is the job of mergeAndDeterminize anyway
        scoreComputer.clearColoursButAmber();
        //System.out.println("finished with speculative update, currently "+scoreComputer.getStateNumber()+" states, "+(scoreComputer.getStateNumber()-scoreComputer.getAmberStateNumber())+" non-amber");
        iterations = 0;counterRestarted++;
        topLearner.Restart(RestartLearningEnum.restartHARD);
      }
      else
      {
        // At this point, scoreComputer may have been modified because it may point to
        // the original PTA which will be modified as a result of new sequences being added to it.
        // temp is different too, hence there is no way for me to compute compatibility score here.
        // This is hence computed inside the obtainPair method.
       
        // keep going with the existing model
        scoreComputer = temp;
        topLearner.Restart(RestartLearningEnum.restartNONE);
      }
     
      possibleMerges = topLearner.ChooseStatePairs(scoreComputer);
    }
    assert currentNonAmber == initialPTA.getStateNumber()-initialPTA.getAmberStateNumber();
    DirectedSparseGraph result = scoreComputer.paths.getGraph();
    if(config.getDebugMode())
      updateGraph(scoreComputer);
    return result;
  }
 
  /** We might be doing a restart, but it never hurts to go through the existing
   * collection of vertices to merge and see if we can update the graph.
   * 
   * @return true if question answering has been cancelled by a user.
   */
  boolean speculativeGraphUpdate(Stack<PairScore> possibleMerges, LearnerGraph originalPTA)
  {
    while(!possibleMerges.isEmpty())
    {
      PairScore pair = possibleMerges.pop();
      int score = pair.getScore();

      if(shouldAskQuestions(score))
      {
        LearnerGraph tempNew = null;
        try
        {
          tempNew = topLearner.MergeAndDeterminize(originalPTA, pair);
        }
        catch(IllegalArgumentException ex)
        {// ignore - tempNew is null anyway         
        }
       
        if (tempNew != null) // merge successful - it would fail if our updates to newPTA have modified scoreComputer (the two are often the same graph)
        {         
          for(List<String> question:getQuestions(tempNew, pair))
          {
            Pair<Integer,String> answer = topLearner.CheckWithEndUser(scoreComputer,question, new Object [] {"Test"});
            this.questionCounter++;
            if (answer.firstElem == USER_CANCELLED)
            {
              System.out.println("CANCELLED");
              return true;
            }
           
            if(answer.firstElem == USER_ACCEPTED)
            {
              topLearner.AugmentPTA(originalPTA, RestartLearningEnum.restartHARD, question, true,JUConstants.AMBER);
              //originalPTA.paths.augmentPTA(question, true,JUConstants.AMBER);
              ++plusSize;
            }
            else
              if(answer.firstElem >= 0)
              {// The sequence has been rejected by a user
                assert answer.firstElem < question.size();
                LinkedList<String> subAnswer = new LinkedList<String>();subAnswer.addAll(question.subList(0, answer.firstElem+1));
                topLearner.AugmentPTA(originalPTA, RestartLearningEnum.restartHARD, subAnswer, false,JUConstants.AMBER);
                //originalPTA.paths.augmentPTA(subAnswer, false,JUConstants.AMBER);
                ++minusSize;
              }
          }
        }
      }
    }
   
    return false;
  }
}
TOP

Related Classes of statechum.analysis.learning.RPNIBlueAmberFringeLearner

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.