Package org.encog.app.analyst.csv

Source Code of org.encog.app.analyst.csv.AnalystEvaluateCSV

/*
* Encog(tm) Core v3.0 - Java Version
* http://www.heatonresearch.com/encog/
* http://code.google.com/p/encog-java/
* Copyright 2008-2011 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*  
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.app.analyst.csv;

import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;

import org.encog.app.analyst.EncogAnalyst;
import org.encog.app.analyst.csv.basic.BasicFile;
import org.encog.app.analyst.csv.basic.LoadedRow;
import org.encog.app.analyst.csv.normalize.AnalystNormalizeCSV;
import org.encog.app.analyst.script.normalize.AnalystField;
import org.encog.app.analyst.util.CSVHeaders;
import org.encog.app.quant.QuantError;
import org.encog.ml.MLClassification;
import org.encog.ml.MLMethod;
import org.encog.ml.MLRegression;
import org.encog.ml.data.MLData;
import org.encog.ml.data.basic.BasicMLData;
import org.encog.util.arrayutil.ClassItem;
import org.encog.util.csv.CSVFormat;
import org.encog.util.csv.ReadCSV;

/**
* Used by the analyst to evaluate a CSV file.
*
*/
public class AnalystEvaluateCSV extends BasicFile {

  /**
   * The analyst to use.
   */
  private EncogAnalyst analyst;
 
  /**
   * The number of columns in the file.
   */
  private int fileColumns;
 
  /**
   * The number of output columns.
   */
  private int outputColumns;
 
  /**
   * Used to handle time series.
   */
  private TimeSeriesUtil series;
 
  /**
   * The headers.
   */
  private CSVHeaders analystHeaders;

  /**
   *  Analyze the data. This counts the records and prepares the data to be
   * processed.
   * @param theAnalyst The analyst to use.
   * @param inputFile The input file.
   * @param headers True if headers are present.
   * @param format The format.
   */
  public final void analyze(final EncogAnalyst theAnalyst,
      final File inputFile,
      final boolean headers, final CSVFormat format) {
    this.setInputFilename(inputFile);
    setExpectInputHeaders(headers);
    setInputFormat(format);

    setAnalyzed(true);
    this.analyst = theAnalyst;

    performBasicCounts();
    this.fileColumns = this.getInputHeadings().length;
    this.outputColumns = this.analyst.determineOutputFieldCount();

    this.analystHeaders = new CSVHeaders(this.getInputHeadings());
    this.series = new TimeSeriesUtil(analyst,
        this.analystHeaders.getHeaders());

  }

  /**
   * Prepare the output file, write headers if needed.
   * @param outputFile The output file.
   * @param input The input count.
   * @param output The output count.
   * @return The file to write to.
   */
  private PrintWriter prepareOutputFile(final File outputFile,
      final int input, final int output) {
    try {
      final PrintWriter tw = new PrintWriter(new FileWriter(outputFile));

      // write headers, if needed
      if (isProduceOutputHeaders()) {
        final StringBuilder line = new StringBuilder();

        // handle provided fields, not all may be used, but all should
        // be displayed
        for (final String heading : this.getInputHeadings()) {
          BasicFile.appendSeparator(line, getOutputFormat());
          line.append("\"");
          line.append(heading);
          line.append("\"");
        }

        // now add the output fields that will be generated
        for (final AnalystField field : this.analyst.getScript()
            .getNormalize().getNormalizedFields()) {
          if (field.isOutput() && !field.isIgnored()) {
            BasicFile.appendSeparator(line, getOutputFormat());
            line.append("\"Output:");
            line.append(CSVHeaders.tagColumn(field.getName(), 0,
                field.getTimeSlice(), false));
            line.append("\"");
          }
        }

        tw.println(line.toString());
      }

      return tw;

    } catch (final IOException e) {
      throw new QuantError(e);
    }
  }

  /**
   * Process the file.
   * @param outputFile The output file.
   * @param method THe method to use.
   */
  public final void process(final File outputFile,
      final MLMethod method) {

    final ReadCSV csv = new ReadCSV(getInputFilename().toString(),
        isExpectInputHeaders(), getInputFormat());

    MLData output = null;

    final int outputLength = this.analyst.determineUniqueColumns();

    final PrintWriter tw = this.prepareOutputFile(outputFile, this.analyst
        .getScript().getNormalize().countActiveFields() - 1, 1);

    resetStatus();
    while (csv.next()) {
      updateStatus(false);
      final LoadedRow row = new LoadedRow(csv, this.outputColumns);

      double[] inputArray = AnalystNormalizeCSV.extractFields(analyst,
          this.analystHeaders, csv, outputLength, false);
      if (this.series.getTotalDepth() > 1) {
        inputArray = this.series.process(inputArray);
      }

      if (inputArray != null) {
        final MLData input = new BasicMLData(inputArray);

        // evaluation data
        if ((method instanceof MLClassification)
            && !(method instanceof MLRegression)) {
          // classification only?
          output = new BasicMLData(1);
          output.setData(0,
              ((MLClassification) method).classify(input));
        } else {
          // regression
          output = ((MLRegression) method).compute(input);
        }

        // skip file data
        int index = this.fileColumns;
        int outputIndex = 0;

        // display output
        for (final AnalystField field : analyst.getScript()
            .getNormalize().getNormalizedFields()) {
          if (this.analystHeaders.find(field.getName()) != -1) {

            if (field.isOutput()) {
              if (field.isClassify()) {
                // classification
                final ClassItem cls = field.determineClass(
                    outputIndex, output.getData());
                outputIndex += field.getColumnsNeeded();
                if (cls == null) {
                  row.getData()[index++] = "?Unknown?";
                } else {
                  row.getData()[index++] = cls.getName();
                }
              } else {
                // regression
                double n = output.getData(outputIndex++);
                n = field.deNormalize(n);
                row.getData()[index++] = getInputFormat()
                    .format(n, getPrecision());
              }
            }
          }
        }
      }

      writeRow(tw, row);
    }
    reportDone(false);
    tw.close();
    csv.close();
  }
}
TOP

Related Classes of org.encog.app.analyst.csv.AnalystEvaluateCSV

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code are property of their respective owners. Java is a trademark of Sun Microsystems, Inc and owned by ORACLE Inc. Contact coftware#gmail.com.