Package cc.mallet.fst

Examples of cc.mallet.fst.MEMMTrainer


    InstanceList testing2 = new InstanceList (pipe2);
    testing2.addThruPipe (new ArrayIterator (data1));

    MEMM memm = new MEMM (pipe2, null);
    memm.addFullyConnectedStatesForLabels ();
    MEMMTrainer memmt = new MEMMTrainer (memm);
    TransducerEvaluator memmeval = new TokenAccuracyEvaluator (new InstanceList[] {training2, testing2}, new String[] {"Training2", "Testing2"});
    memmt.train (training2, 5);
    memmeval.evaluate(memmt);

    CRFExtractor extor2 = hackCrfExtor (memm);
    Extraction e2 = extor2.extract (new ArrayIterator (data1));
View Full Code Here


    MEMM memm = new MEMM (inputAlphabet, outputAlphabet);
    String[] stateNames = new String[numStates];
    for (int i = 0; i < numStates; i++)
      stateNames[i] = "state" + i;
    memm.addFullyConnectedStates(stateNames);
    MEMMTrainer memmt = new MEMMTrainer (memm);
    MEMMTrainer.MEMMOptimizableByLabelLikelihood omemm = memmt.getOptimizableMEMM (new InstanceList(null));
    TestOptimizable.testGetSetParameters(omemm);
  }
View Full Code Here

    MEMM memm = new MEMM (p, null);
    memm.addFullyConnectedStatesForLabels ();
    memm.addStartState();
    memm.setWeightsDimensionAsIn(training);
   
    MEMMTrainer memmt = new MEMMTrainer (memm);
//    memm.gatherTrainingSets (training); // ANNOYING: Need to set up per-instance training sets
    memmt.train (training, 1)// Set weights dimension, gathers training sets, etc.

//    memm.print();
//    memm.printGradient = true;
//    memm.printInstanceLists();

//    memm.setGaussianPriorVariance (Double.POSITIVE_INFINITY);
    Optimizable.ByGradientValue mcrf = memmt.getOptimizableMEMM(training);
    TestOptimizable.setNumComponents (150);
    TestOptimizable.testValueAndGradient (mcrf);
  }
View Full Code Here

    MEMM memm = new MEMM (p, null);
    memm.addFullyConnectedStatesForLabels ();
    memm.addStartState();
    memm.setWeightsDimensionAsIn(training);
    MEMMTrainer memmt = new MEMMTrainer (memm);
    memmt.train (training, 10);

    MEMM memm2 = (MEMM) TestSerializable.cloneViaSerialization (memm);

    Optimizable.ByGradientValue mcrf1 = memmt.getOptimizableMEMM(training);
    double val1 = mcrf1.getValue ();
    Optimizable.ByGradientValue mcrf2 = memmt.getOptimizableMEMM(training);
    double val2 = mcrf2.getValue ();

    assertEquals (val1, val2, 1e-5);
  }
View Full Code Here

    // Store the dictionary
    if (outputAlphabet == null) {
      System.err.println("Output dictionary null.");
    }
    MEMM crf = new MEMM(inputAlphabet, outputAlphabet);
    MEMMTrainer memmt = new MEMMTrainer (crf);

    String[] stateNames = new String[numStates];
    for (int i = 0; i < numStates; i++)
      stateNames[i] = "state" + i;
    MEMM saveCRF = crf;
    //inputAlphabet = (Feature.Alphabet) crf.getInputAlphabet();
    FeatureVectorSequence fvs = new FeatureVectorSequence(new FeatureVector[]{
      new FeatureVector(crf.getInputAlphabet(), new int[]{1, 2, 3}, new double[]{1, 1, 1}),
      new FeatureVector(crf.getInputAlphabet(), new int[]{1, 2, 3}, new double[]{1, 1, 1}),
      new FeatureVector(crf.getInputAlphabet(), new int[]{1, 2, 3}, new double[]{1, 1, 1}),
      new FeatureVector(crf.getInputAlphabet(), new int[]{1, 2, 3}, new double[]{1, 1, 1}),
    });
    FeatureSequence ss = new FeatureSequence(crf.getOutputAlphabet(), new int[]{0, 1, 2, 3});
    InstanceList ilist = new InstanceList(null);
    ilist.add(fvs, ss, null, null);

    crf.addFullyConnectedStates(stateNames);

    try {
      ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f));
      oos.writeObject(crf);
      oos.close();
    } catch (IOException e) {
      System.err.println("Exception writing file: " + e);
    }
    System.err.println("Wrote out CRF");
    // And read it back in
    crf = null;
    try {
      ObjectInputStream ois = new ObjectInputStream(new FileInputStream(f));
      crf = (MEMM) ois.readObject();
      ois.close();
    } catch (IOException e) {
      System.err.println("Exception reading file: " + e);
    } catch (ClassNotFoundException cnfe) {
      System.err.println("Cound not find class reading in object: " + cnfe);
    }
    System.err.println("Read in CRF.");

    try {
      ObjectOutputStream oos = new ObjectOutputStream(new FileOutputStream(f2));
      oos.writeObject(crf);
      oos.close();
    } catch (IOException e) {
      System.err.println("Exception writing file: " + e);
    }
    System.err.println("Wrote out CRF");
    if (useSave == 1) {
      crf = saveCRF;
    }
//    MEMM.OptimizableCRF mcrf = crf.getMaximizableCRF(ilist);
    Optimizable.ByGradientValue mcrf = memmt.getOptimizableMEMM(ilist);

    double unconstrainedCost = new SumLatticeDefault (crf, fvs).getTotalWeight();
    double constrainedCost = new SumLatticeDefault (crf, fvs, ss).getTotalWeight();
    double minimizableCost = 0, minimizableGradientNorm = 0;
    double[] gradient = new double [mcrf.getNumParameters()];
View Full Code Here

    InstanceList[] lists = instances.split(new double[]{.5, .5});
    MEMM memm = new MEMM(p, p2);
    memm.addFullyConnectedStatesForLabels();
    memm.setWeightsDimensionAsIn(lists[0]);
   
    MEMMTrainer memmt = new MEMMTrainer (memm);
    if (testValueAndGradient) {
      Optimizable.ByGradientValue minable = memmt.getOptimizableMEMM(lists[0]);
      TestOptimizable.testValueAndGradient(minable);
    } else {
      System.out.println("Training Accuracy before training = " + memm.averageTokenAccuracy(lists[0]));
      System.out.println("Testing  Accuracy before training = " + memm.averageTokenAccuracy(lists[1]));
      System.out.println("Training...");
      memmt.train(lists[0], 1);
      System.out.println("Training Accuracy after training = " + memm.averageTokenAccuracy(lists[0]));
      System.out.println("Testing  Accuracy after training = " + memm.averageTokenAccuracy(lists[1]));
      System.out.println("Training results:");
      for (int i = 0; i < lists[0].size(); i++) {
        Instance inst = lists[0].get(i);
View Full Code Here

    if (useSparseWeights)
      crf.setWeightsDimensionAsIn(lists[0]);
    else
      crf.setWeightsDimensionDensely();
   
    MEMMTrainer memmt = new MEMMTrainer (crf);
    // memmt.setUseSparseWeights (useSparseWeights);
    if (testValueAndGradient) {
      Optimizable.ByGradientValue minable = memmt.getOptimizableMEMM(lists[0]);
      TestOptimizable.testValueAndGradient(minable);
    } else {
      System.out.println("Training Accuracy before training = " + crf.averageTokenAccuracy(lists[0]));
      System.out.println("Testing  Accuracy before training = " + crf.averageTokenAccuracy(lists[1]));
      savedCRF = crf;
      System.out.println("Training serialized crf.");
      memmt.train(lists[0], 100);
      double preTrainAcc = crf.averageTokenAccuracy(lists[0]);
      double preTestAcc = crf.averageTokenAccuracy(lists[1]);
      System.out.println("Training Accuracy after training = " + preTrainAcc);
      System.out.println("Testing  Accuracy after training = " + preTestAcc);
      try {
View Full Code Here

                         "START",
                         null,
                         null,
                         false);
    crf1.setWeightsDimensionAsIn(lists[0]);
    MEMMTrainer memmt1 = new MEMMTrainer (crf1);
    memmt1.train(lists [0]);


    MEMM crf2 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());
    crf2.addOrderNStates (lists [0],
                           new int[] { 1, 2, },
                           new boolean[] { false, true },
                           "START",
                           null,
                           null,
                           false);
    crf2.setWeightsDimensionAsIn(lists[0]);
    MEMMTrainer memmt2 = new MEMMTrainer (crf2);
    memmt2.train(lists [0]);


    MEMM crf3 = new MEMM(p.getDataAlphabet(), p.getTargetAlphabet());
    crf3.addOrderNStates (lists [0],
                         new int[] { 1, 2, },
                         new boolean[] { false, false },
                         "START",
                         null,
                         null,
                         false);
    crf3.setWeightsDimensionAsIn(lists[0]);
    MEMMTrainer memmt3 = new MEMMTrainer (crf3);
    memmt3.train(lists [0]);

    // Prevent cached values
    double lik1 = getLikelihood (memmt1, lists[0]);
    double lik2 = getLikelihood (memmt2, lists[0]);
    double lik3 = getLikelihood (memmt3, lists[0]);
View Full Code Here

    String[] data = new String[] { "ABCDE", };
    one.addThruPipe (new ArrayIterator (data));
    MEMM crf = new MEMM (p, null);
    crf.addFullyConnectedStatesForLabels();
    crf.setWeightsDimensionAsIn (one);
    MEMMTrainer memmt = new MEMMTrainer (crf);
    MEMMTrainer.MEMMOptimizableByLabelLikelihood mcrf = memmt.getOptimizableMEMM(one);
    double[] params = new double[mcrf.getNumParameters()];
    for (int i = 0; i < params.length; i++) {
      params [i] = i;
    }
    mcrf.setParameters (params);
View Full Code Here

TOP

Related Classes of cc.mallet.fst.MEMMTrainer

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc., now owned by Oracle Corporation. Contact coftware#gmail.com.