Package cc.mallet.types

Examples of cc.mallet.types.InstanceList$CrossValidationIterator
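None of the example snippets below call CrossValidationIterator directly; they build an InstanceList through a Pipe and then split or use it by hand. For orientation, here is a minimal sketch of the cross-validation API itself. It assumes InstanceList.crossValidationIterator(int) returns an InstanceList.CrossValidationIterator whose nextSplit() yields a two-element array (training folds first, held-out fold second); verify the exact signatures against your MALLET version.

  import cc.mallet.types.InstanceList;

  public class CrossValidationSketch
  {
    // Split the data into numFolds folds and visit each train/test pair.
    public static void runFolds (InstanceList data, int numFolds)
    {
      InstanceList.CrossValidationIterator folds =
        data.crossValidationIterator(numFolds);
      int fold = 0;
      while (folds.hasNext())
      {
        // Assumed order: [0] = training folds, [1] = held-out fold.
        InstanceList[] split = folds.nextSplit();
        InstanceList training = split[0];
        InstanceList testing = split[1];
        System.out.println("Fold " + (fold++) + ": " + training.size()
            + " training / " + testing.size() + " testing instances");
        // Train and evaluate a model on (training, testing) here.
      }
    }
  }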


  /**
   * @exception Exception if an error occurs
   */
  public static void main (String[] args) throws Exception
  {
    Reader trainingFile = null, testFile = null;
    InstanceList trainingData = null, testData = null;
    int numEvaluations = 0;
    int iterationsBetweenEvals = 16;
    int restArgs = commandOptions.processOptions(args);
    if (restArgs == args.length)
    {
      commandOptions.printUsage(true);
      throw new IllegalArgumentException("Missing data file(s)");
    }
    if (trainOption.value)
    {
      trainingFile = new FileReader(new File(args[restArgs]));
      if (testOption.value != null && restArgs < args.length - 1)
        testFile = new FileReader(new File(args[restArgs+1]));
    } else
      testFile = new FileReader(new File(args[restArgs]));

    Pipe p = null;
    CRF crf = null;
    TransducerEvaluator eval = null;
    if (continueTrainingOption.value || !trainOption.value) {
      if (modelOption.value == null)
      {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Missing model file option");
      }
      ObjectInputStream s =
        new ObjectInputStream(new FileInputStream(modelOption.value));
      crf = (CRF) s.readObject();
      s.close();
      p = crf.getInputPipe();
    }
    else {
      p = new SimpleTaggerSentence2FeatureVectorSequence();
      p.getTargetAlphabet().lookupIndex(defaultOption.value);
    }


    if (trainOption.value)
    {
      p.setTargetProcessing(true);
      trainingData = new InstanceList(p);
      trainingData.addThruPipe(
          new LineGroupIterator(trainingFile,
            Pattern.compile("^\\s*$"), true));
      logger.info
        ("Number of features in training data: "+p.getDataAlphabet().size());
      if (testOption.value != null)
      {
        if (testFile != null)
        {
          testData = new InstanceList(p);
          testData.addThruPipe(
              new LineGroupIterator(testFile,
                Pattern.compile("^\\s*$"), true));
        }
        else
        {
          Random r = new Random (randomSeedOption.value);
          InstanceList[] trainingLists =
            trainingData.split(
                r, new double[] {trainingFractionOption.value,
                  1-trainingFractionOption.value});
          trainingData = trainingLists[0];
          testData = trainingLists[1];
        }
      }
    } else if (testOption.value != null)
    {
      p.setTargetProcessing(true);
      testData = new InstanceList(p);
      testData.addThruPipe(
          new LineGroupIterator(testFile,
            Pattern.compile("^\\s*$"), true));
    } else
    {
      p.setTargetProcessing(false);
      testData = new InstanceList(p);
      testData.addThruPipe(
          new LineGroupIterator(testFile,
            Pattern.compile("^\\s*$"), true));
    }
    logger.info ("Number of predicates: "+p.getDataAlphabet().size());


  /**
   * @exception Exception if an error occurs
   */
  public static void main (String[] args) throws Exception
  {
    Reader trainingFile = null, testFile = null;
    InstanceList trainingData = null, testData = null;
    int numEvaluations = 0;
    int iterationsBetweenEvals = 16;
    int restArgs = commandOptions.processOptions(args);

    Pipe p = null;
    CRF crf = null;
    TransducerEvaluator eval = null;
    if (continueTrainingOption.value || !trainOption.value) {
      if (modelOption.value == null)
      {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Missing model file option");
      }
      ObjectInputStream s =
        new ObjectInputStream(new FileInputStream(modelOption.value));
      crf = (CRF) s.readObject();
      s.close();
      p = crf.getInputPipe();
    }
    else {
      p = new SimpleTaggerSentence2FeatureVectorSequence();
      p.getTargetAlphabet().lookupIndex(defaultOption.value);
    }


    if (trainOption.value)
    {
      p.setTargetProcessing(true);
      trainingData = new InstanceList(p);
      trainingData.addThruPipe(
          new LineGroupIterator(trainingFile,
            Pattern.compile("^\\s*$"), true));
      logger.info
        ("Number of features in training data: "+p.getDataAlphabet().size());
      if (testOption.value != null)
      {
        if (testFile != null)
        {
          testData = new InstanceList(p);
          testData.addThruPipe(
              new LineGroupIterator(testFile,
                Pattern.compile("^\\s*$"), true));
        }
        else
        {
          Random r = new Random (randomSeedOption.value);
          InstanceList[] trainingLists =
            trainingData.split(
                r, new double[] {trainingFractionOption.value,
                  1-trainingFractionOption.value});
          trainingData = trainingLists[0];
          testData = trainingLists[1];
        }
      }
    } else if (testOption.value != null)
    {
      p.setTargetProcessing(true);
      testData = new InstanceList(p);
      testData.addThruPipe(
          new LineGroupIterator(testFile,
            Pattern.compile("^\\s*$"), true));
    } else
    {
      p.setTargetProcessing(false);
      testData = new InstanceList(p);
      //testData.addThruPipe(
      //    new LineGroupIterator(testFile,
      //      Pattern.compile("^\\s*$"), true));
    }
    //logger.info ("Number of predicates: "+p.getDataAlphabet().size());
   
   
    if (testOption.value != null)
    {
      if (testOption.value.startsWith("lab"))
        eval = new TokenAccuracyEvaluator(new InstanceList[] {trainingData, testData}, new String[] {"Training", "Testing"});
      else if (testOption.value.startsWith("seg="))
      {
        String[] pairs = testOption.value.substring(4).split(",");
        if (pairs.length < 1)
        {
          commandOptions.printUsage(true);
          throw new IllegalArgumentException(
              "Missing segment start/continue labels: " + testOption.value);
        }
        String startTags[] = new String[pairs.length];
        String continueTags[] = new String[pairs.length];
        for (int i = 0; i < pairs.length; i++)
        {
          String[] pair = pairs[i].split("\\.");
          if (pair.length != 2)
          {
            commandOptions.printUsage(true);
            throw new
              IllegalArgumentException(
                  "Incorrectly-specified segment start and end labels: " +
                  pairs[i]);
          }
          startTags[i] = pair[0];
          continueTags[i] = pair[1];
        }
        eval = new MultiSegmentationEvaluator(new InstanceList[] {trainingData, testData}, new String[] {"Training", "Testing"},
            startTags, continueTags);
      }
      else
      {
        commandOptions.printUsage(true);
        throw new IllegalArgumentException("Invalid test option: " +
            testOption.value);
      }
    }
   
   
   
    if (p.isTargetProcessing())
    {
      Alphabet targets = p.getTargetAlphabet();
      StringBuffer buf = new StringBuffer("Labels:");
      for (int i = 0; i < targets.size(); i++)
        buf.append(" ").append(targets.lookupObject(i).toString());
      logger.info(buf.toString());
    }
    if (trainOption.value)
    {
      crf = train(trainingData, testData, eval,
          ordersOption.value, defaultOption.value,
          forbiddenOption.value, allowedOption.value,
          connectedOption.value, iterationsOption.value,
          gaussianVarianceOption.value, crf);
      if (modelOption.value != null)
      {
        ObjectOutputStream s =
          new ObjectOutputStream(new FileOutputStream(modelOption.value));
        s.writeObject(crf);
        s.close();
      }
    }
    else
    {
      if (crf == null)
      {
        if (modelOption.value == null)
        {
          commandOptions.printUsage(true);
          throw new IllegalArgumentException("Missing model file option");
        }
        ObjectInputStream s =
          new ObjectInputStream(new FileInputStream(modelOption.value));
        crf = (CRF) s.readObject();
        s.close();
      }
      if (eval != null)
        test(new NoopTransducerTrainer(crf), eval, testData);
      else
      {
        boolean includeInput = includeInputOption.value();
        Scanner scanner = new Scanner(System.in);
        Pattern pattern = Pattern.compile("^\\s*$");

        int nLines = 0;
        while (scanner.hasNextLine()) {
            String line = scanner.nextLine();
            line = line.replace('\t', '\n');
            testData = new InstanceList(p);
            testData.addThruPipe(new LineGroupIterator(new StringReader(line),
                                                       pattern, true));

            for (int i = 0; i < testData.size(); i++) {
                Sequence input = (Sequence)testData.get(i).getData();

      testSource = new LineGroupIterator (new FileReader (testFile.value), Pattern.compile ("^\\s*$"), true);
    } else {
      testSource = null;
    }

    InstanceList training = new InstanceList (pipe);
    training.addThruPipe (trainSource);
    InstanceList testing = new InstanceList (pipe);
    testing.addThruPipe (testSource);

    ACRF.Template[] tmpls = parseModelFile (modelFile.value);
    ACRFEvaluator eval = createEvaluator (evalOption.value);

    Inferencer inf = createInferencer (inferencerOption.value);

  private static CommandOption.Double majorityProb = new CommandOption.Double(Vectors2FeatureConstraints.class, "majority-prob", "DOUBLE",
      false, 0.9, "Probability for majority labels when using heuristic target estimation.", null);

  public static void main(String[] args) {
    CommandOption.process(Vectors2FeatureConstraints.class, args);
    InstanceList list = InstanceList.load(vectorsFile.value);
   
    // Here we will assume that we use all labeled data available. 
    ArrayList<Integer> features = null;
    HashMap<Integer,ArrayList<Integer>> featuresAndLabels = null;

    // if a features file was specified, then load features from the file
    if (featuresFile.wasInvoked()) {
      if (fileContainsLabels(featuresFile.value)) {
        // better error message from dfrankow@gmail.com
        if (targets.value.equals("oracle")) {
          throw new RuntimeException("with --targets oracle, features file must be unlabeled");
        }
        featuresAndLabels = readFeaturesAndLabelsFromFile(featuresFile.value, list.getDataAlphabet(), list.getTargetAlphabet());
      }
      else {
        features = readFeaturesFromFile(featuresFile.value, list.getDataAlphabet());       
      }
    }
   
    // otherwise select features using specified method
    else {
      if (featureSelection.value.equals("infogain")) {
        features = FeatureConstraintUtil.selectFeaturesByInfoGain(list,numConstraints.value);
      }
      else if (featureSelection.value.equals("lda")) {
        try {
          ObjectInputStream ois = new ObjectInputStream(new FileInputStream(ldaFile.value));
          ParallelTopicModel lda = (ParallelTopicModel)ois.readObject();
          features = FeatureConstraintUtil.selectTopLDAFeatures(numConstraints.value, lda, list.getDataAlphabet());
        }
        catch (Exception e) {
          e.printStackTrace();
        }
      }
      else {
        throw new RuntimeException("Unsupported value for feature selection: " + featureSelection.value);
      }
    }
   
    // If the target method is oracle, then we do not need feature "labels".
    HashMap<Integer,double[]> constraints = null;
   
    if (targets.value.equals("none")) {
      constraints = new HashMap<Integer,double[]>();
      for (int fi : features) {    
        constraints.put(fi, null);
      }
    }
    else if (targets.value.equals("oracle")) {
      constraints = FeatureConstraintUtil.setTargetsUsingData(list, features);
    }
    else {
      // For other methods, we need to get feature labels, as
      // long as they haven't been already loaded from disk.
      if (featuresAndLabels == null) {
        featuresAndLabels = FeatureConstraintUtil.labelFeatures(list,features);
       
        for (int fi : featuresAndLabels.keySet()) {
          logger.info(list.getDataAlphabet().lookupObject(fi) + ":  ");
          for (int li : featuresAndLabels.get(fi)) {
            logger.info(list.getTargetAlphabet().lookupObject(li) + " ");
          }
        }
       
      }
      if (targets.value.equals("heuristic")) {
        constraints = FeatureConstraintUtil.setTargetsUsingHeuristic(featuresAndLabels,list.getTargetAlphabet().size(),majorityProb.value);
      }
      else if (targets.value.equals("voted")) {
        constraints = FeatureConstraintUtil.setTargetsUsingFeatureVoting(featuresAndLabels,list);
      }
      else {
        throw new RuntimeException("Unsupported value for targets: " + targets.value);
      }
    }
    writeConstraints(constraints, constraintsFile.value, list.getDataAlphabet(), list.getTargetAlphabet());
  }

        Integer numTopics = (Integer)input.get(0); // Number of topics to discover
        DataBag documents = (DataBag)input.get(1); // Documents, {(doc_id, text)}
        DataBag result = BagFactory.getInstance().newDefaultBag();

        InstanceList instances = new InstanceList(pipe);

        // Add the input databag as source data and run it through the pipe built
        // by the constructor.
        instances.addThruPipe(new DataBagSourceIterator(documents));

        // Create a model with numTopics topics, alphaSum = 1.0 and beta_w = 0.01.
        // Note that the alpha hyperparameter is passed as the sum over topics
        // (so alpha_t = 1.0/numTopics), while beta is the parameter for a single
        // dimension of the Dirichlet prior.
        ParallelTopicModel model = new ParallelTopicModel(numTopics, 1.0, 0.01);
        model.addInstances(instances);
        model.setNumThreads(1); // Important: this runs inside the reducer, so use only one thread
        model.setTopicDisplay(0,0);
        model.setNumIterations(2000);
        model.estimate();

        // Get the results
        Alphabet dataAlphabet = instances.getDataAlphabet();
        ArrayList<TopicAssignment> assignments = model.getData();

        // Convert the results into comprehensible topics
        for (int topicNum = 0; topicNum < model.getNumTopics(); topicNum++) {
            TreeSet<IDSorter> sortedWords = model.getSortedWords().get(topicNum);

                this.corpus = corpus;
        }
               
        public void evaluate() {
            Pipe pipe = buildPipe();
            InstanceList instances = new InstanceList(pipe);
            for(Document document : corpus.getDocuments()) {
                Instance instance = new Instance(document.getDocumentString(),null,null,document.getDocumentString());
                instance.setData(document.getDocumentString());
                instances.addThruPipe(instance);
            }
            LDA.addInstances(instances);
            try {
                LDA.estimate();
            } catch (IOException e) {

                this.corpus = corpus;
        }
               
        public void evaluate() {
            Pipe pipe = buildPipe();
            InstanceList instances = new InstanceList(pipe);
            for(Document document : corpus.getDocuments()) {
                Instance instance = new Instance(document.getDocumentString(),null,null,document.getDocumentString());
                instance.setData(document.getDocumentString());
                instances.addThruPipe(instance);
            }
            LDA.addInstances(instances);
            try {
        LDA.sample(numIterations);
      } catch (IOException e) {
