Package cc.mallet.fst

Examples of cc.mallet.fst.MEMM$TransitionIterator


    InstanceList training2 = new InstanceList (pipe2);
    training2.addThruPipe (new ArrayIterator (data0));
    InstanceList testing2 = new InstanceList (pipe2);
    testing2.addThruPipe (new ArrayIterator (data1));

    MEMM memm = new MEMM (pipe2, null);
    memm.addFullyConnectedStatesForLabels ();
    MEMMTrainer memmt = new MEMMTrainer (memm);
    TransducerEvaluator memmeval = new TokenAccuracyEvaluator (new InstanceList[] {training2, testing2}, new String[] {"Training2", "Testing2"});
    memmt.train (training2, 5);
    memmeval.evaluate(memmt);
View Full Code Here


    int numStates = 5;
    Alphabet inputAlphabet = new Alphabet();
    for (int i = 0; i < inputVocabSize; i++)
      inputAlphabet.lookupIndex("feature" + i);
    Alphabet outputAlphabet = new Alphabet();
    MEMM memm = new MEMM (inputAlphabet, outputAlphabet);
    String[] stateNames = new String[numStates];
    for (int i = 0; i < numStates; i++)
      stateNames[i] = "state" + i;
    memm.addFullyConnectedStates(stateNames);
    MEMMTrainer memmt = new MEMMTrainer (memm);
    MEMMTrainer.MEMMOptimizableByLabelLikelihood omemm = memmt.getOptimizableMEMM (new InstanceList(null));
    TestOptimizable.testGetSetParameters(omemm);
  }
View Full Code Here

//    String[] data = { TestMEMM.data[0], }; // TestMEMM.data[1], TestMEMM.data[2], TestMEMM.data[3], };
//    String[] data = { "ab" };
    training.addThruPipe (new ArrayIterator (data));

//    CRF4 memm = new CRF4 (p, null);
    MEMM memm = new MEMM (p, null);
    memm.addFullyConnectedStatesForLabels ();
    memm.addStartState();
    memm.setWeightsDimensionAsIn(training);
   
    MEMMTrainer memmt = new MEMMTrainer (memm);
//    memm.gatherTrainingSets (training); // ANNOYING: Need to set up per-instance training sets
    memmt.train (training, 1)// Set weights dimension, gathers training sets, etc.
View Full Code Here

  {
    Pipe p = makeSpacePredictionPipe ();
    InstanceList training = new InstanceList (p);
    training.addThruPipe (new ArrayIterator (data));

    MEMM memm = new MEMM (p, null);
    memm.addFullyConnectedStatesForLabels ();
    memm.addStartState();
    memm.setWeightsDimensionAsIn(training);
    MEMMTrainer memmt = new MEMMTrainer (memm);
    memmt.train (training, 10);

    MEMM memm2 = (MEMM) TestSerializable.cloneViaSerialization (memm);

    Optimizable.ByGradientValue mcrf1 = memmt.getOptimizableMEMM(training);
    double val1 = mcrf1.getValue ();
    Optimizable.ByGradientValue mcrf2 = memmt.getOptimizableMEMM(training);
    double val2 = mcrf2.getValue ();
View Full Code Here

    Alphabet outputAlphabet = new Alphabet();
    // Store the dictionary
    if (outputAlphabet == null) {
      System.err.println("Output dictionary null.");
    }
    MEMM crf = new MEMM(inputAlphabet, outputAlphabet);
    MEMMTrainer memmt = new MEMMTrainer (crf);

    String[] stateNames = new String[numStates];
    for (int i = 0; i < numStates; i++)
      stateNames[i] = "state" + i;
    MEMM saveCRF = crf;
    //inputAlphabet = (Feature.Alphabet) crf.getInputAlphabet();
    FeatureVectorSequence fvs = new FeatureVectorSequence(new FeatureVector[]{
      new FeatureVector(crf.getInputAlphabet(), new int[]{1, 2, 3}, new double[]{1, 1, 1}),
      new FeatureVector(crf.getInputAlphabet(), new int[]{1, 2, 3}, new double[]{1, 1, 1}),
      new FeatureVector(crf.getInputAlphabet(), new int[]{1, 2, 3}, new double[]{1, 1, 1}),
View Full Code Here

    Pipe p2 = new TestMEMM2String();

    InstanceList instances = new InstanceList(p);
    instances.addThruPipe(new ArrayIterator(data));
    InstanceList[] lists = instances.split(new double[]{.5, .5});
    MEMM memm = new MEMM(p, p2);
    memm.addFullyConnectedStatesForLabels();
    memm.setWeightsDimensionAsIn(lists[0]);
   
    MEMMTrainer memmt = new MEMMTrainer (memm);
    if (testValueAndGradient) {
      Optimizable.ByGradientValue minable = memmt.getOptimizableMEMM(lists[0]);
      TestOptimizable.testValueAndGradient(minable);
    } else {
      System.out.println("Training Accuracy before training = " + memm.averageTokenAccuracy(lists[0]));
      System.out.println("Testing  Accuracy before training = " + memm.averageTokenAccuracy(lists[1]));
      System.out.println("Training...");
      memmt.train(lists[0], 1);
      System.out.println("Training Accuracy after training = " + memm.averageTokenAccuracy(lists[0]));
      System.out.println("Testing  Accuracy after training = " + memm.averageTokenAccuracy(lists[1]));
      System.out.println("Training results:");
      for (int i = 0; i < lists[0].size(); i++) {
        Instance inst = lists[0].get(i);
        Sequence input = (Sequence) inst.getData ();
        Sequence output = memm.transduce (input);
        System.out.println (output);
      }
      System.out.println ("Testing results:");
      for (int i = 0; i < lists[1].size(); i++) {
        Instance inst = lists[1].get(i);
        Sequence input = (Sequence) inst.getData ();
        Sequence output = memm.transduce (input);
        System.out.println (output);
      }
    }
  }
View Full Code Here

                                    boolean useSaved,
                                    boolean useSparseWeights)
  {
    Pipe p = makeSpacePredictionPipe ();

    MEMM savedCRF;
    File f = new File("TestObject.obj");
    InstanceList instances = new InstanceList(p);
    instances.addThruPipe(new ArrayIterator(data));
    InstanceList[] lists = instances.split(new double[]{.5, .5});
    MEMM crf = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());
    crf.addFullyConnectedStatesForLabels();
    if (useSparseWeights)
      crf.setWeightsDimensionAsIn(lists[0]);
    else
      crf.setWeightsDimensionDensely();
   
    MEMMTrainer memmt = new MEMMTrainer (crf);
    // memmt.setUseSparseWeights (useSparseWeights);
    if (testValueAndGradient) {
      Optimizable.ByGradientValue minable = memmt.getOptimizableMEMM(lists[0]);
      TestOptimizable.testValueAndGradient(minable);
    } else {
      System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0]));
      System.out.println("Testing  Accuracy before training = " + crf.averageTokenAccuracy(lists[1]));
      savedCRF = crf;
      System.out.println("Training serialized crf.");
      memmt.train(lists[0], 100);
      double preTrainAcc = crf.averageTokenAccuracy(lists[0]);
      double preTestAcc = crf.averageTokenAccuracy(lists[1]);
      System.out.println("Training Accuracy after training = " + preTrainAcc);
      System.out.println("Testing  Accuracy after training = " + preTestAcc);
      try {
        ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));
        oos.writeObject(crf);
        oos.close();
      } catch (IOException e) {
        System.err.println("Exception writing file: " + e);
      }
      System.err.println("Wrote out CRF");
      // And read it back in
      if (useSaved) {
        crf = null;
        try {
          ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f));
          crf = (MEMM) ois.readObject();
          ois.close();
        } catch (IOException e) {
          System.err.println("Exception reading file: " + e);
        } catch (ClassNotFoundException cnfe) {
          System.err.println("Cound not find class reading in object: " + cnfe);
        }
        System.err.println("Read in CRF.");
        crf = savedCRF;

        double postTrainAcc = crf.averageTokenAccuracy(lists[0]);
        double postTestAcc = crf.averageTokenAccuracy(lists[1]);
        System.out.println("Training Accuracy after saving = " + postTrainAcc);
        System.out.println("Testing  Accuracy after saving = " + postTestAcc);

        assertEquals(postTrainAcc, preTrainAcc, 0.0001);
        assertEquals(postTestAcc, preTestAcc, 0.0001);
View Full Code Here

    InstanceList[] lists = instances.split (new java.util.Random (678), new double[]{.5, .5});

    // Compare 3 CRFs trained with addOrderNStates, and make sure
    // that having more features leads to a higher likelihood

    MEMM crf1 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());
    crf1.addOrderNStates (lists [0],
                         new int[] { 1, },
                         new boolean[] { false, },
                         "START",
                         null,
                         null,
                         false);
    crf1.setWeightsDimensionAsIn(lists[0]);
    MEMMTrainer memmt1 = new MEMMTrainer (crf1);
    memmt1.train(lists [0]);


    MEMM crf2 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());
    crf2.addOrderNStates (lists [0],
                           new int[] { 1, 2, },
                           new boolean[] { false, true },
                           "START",
                           null,
                           null,
                           false);
    crf2.setWeightsDimensionAsIn(lists[0]);
    MEMMTrainer memmt2 = new MEMMTrainer (crf2);
    memmt2.train(lists [0]);


    MEMM crf3 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());
    crf3.addOrderNStates (lists [0],
                         new int[] { 1, 2, },
                         new boolean[] { false, false },
                         "START",
                         null,
                         null,
                         false);
    crf3.setWeightsDimensionAsIn(lists[0]);
    MEMMTrainer memmt3 = new MEMMTrainer (crf3);
    memmt3.train(lists [0]);

    // Prevent cached values
    double lik1 = getLikelihood (memmt1, lists[0]);
View Full Code Here

       new PrintInputAndTarget(),
    });
    InstanceList one = new InstanceList (p);
    String[] data = new String[] { "ABCDE", };
    one.addThruPipe (new ArrayIterator (data));
    MEMM crf = new MEMM (p, null);
    crf.addFullyConnectedStatesForLabels();
    crf.setWeightsDimensionAsIn (one);
    MEMMTrainer memmt = new MEMMTrainer (crf);
    MEMMTrainer.MEMMOptimizableByLabelLikelihood mcrf = memmt.getOptimizableMEMM(one);
    double[] params = new double[mcrf.getNumParameters()];
    for (int i = 0; i < params.length; i++) {
      params [i] = i;
    }
    mcrf.setParameters (params);
    crf.print ();
  }
View Full Code Here

TOP

Related Classes of cc.mallet.fst.MEMM$TransitionIterator

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.