/*
* Encog(tm) Core v3.3 - Java Version
* http://www.heatonresearch.com/encog/
* https://github.com/encog/encog-java-core
* Copyright 2008-2014 Heaton Research, Inc.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*
* For more information on Heaton Research copyrights, licenses
* and trademarks visit:
* http://www.heatonresearch.com/copyright
*/
package org.encog.app.analyst;
import java.io.Closeable;
import java.io.File;
import java.io.FileInputStream;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.URL;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.zip.GZIPInputStream;
import org.encog.app.analyst.analyze.PerformAnalysis;
import org.encog.app.analyst.commands.Cmd;
import org.encog.app.analyst.commands.CmdBalance;
import org.encog.app.analyst.commands.CmdCluster;
import org.encog.app.analyst.commands.CmdCode;
import org.encog.app.analyst.commands.CmdCreate;
import org.encog.app.analyst.commands.CmdEvaluate;
import org.encog.app.analyst.commands.CmdEvaluateRaw;
import org.encog.app.analyst.commands.CmdGenerate;
import org.encog.app.analyst.commands.CmdNormalize;
import org.encog.app.analyst.commands.CmdProcess;
import org.encog.app.analyst.commands.CmdRandomize;
import org.encog.app.analyst.commands.CmdReset;
import org.encog.app.analyst.commands.CmdSegregate;
import org.encog.app.analyst.commands.CmdSet;
import org.encog.app.analyst.commands.CmdTrain;
import org.encog.app.analyst.script.AnalystScript;
import org.encog.app.analyst.script.normalize.AnalystField;
import org.encog.app.analyst.script.prop.ScriptProperties;
import org.encog.app.analyst.script.task.AnalystTask;
import org.encog.app.analyst.util.AnalystUtility;
import org.encog.app.analyst.wizard.AnalystWizard;
import org.encog.app.quant.QuantTask;
import org.encog.bot.BotUtil;
import org.encog.ml.MLMethod;
import org.encog.ml.bayesian.BayesianNetwork;
import org.encog.ml.train.MLTrain;
import org.encog.util.Format;
import org.encog.util.logging.EncogLogging;
/**
 * The Encog Analyst runs Encog Analyst Script files (EGA) to perform many
 * common machine learning tasks. It is very much like Maven or ANT for Encog.
 * Encog analyst files are made up of configuration information and tasks.
 * Tasks are series of commands that make use of the configuration information
 * to process CSV files.
 *
 * <p>Each task line is dispatched to a registered {@link Cmd}, looked up by
 * the upper-cased first word of the line. Progress, training and command
 * start/end events are forwarded to every registered
 * {@link AnalystListener}.</p>
 */
public class EncogAnalyst {

    /**
     * The name of the task that is meant to run every step.
     */
    public static final String TASK_FULL = "task-full";

    /**
     * The analyst script.
     */
    private final AnalystScript script = new AnalystScript();

    /**
     * The listeners.
     */
    private final List<AnalystListener> listeners
            = new ArrayList<AnalystListener>();

    /**
     * The update interval for download/uncompress progress, measured in
     * buffer reads between two progress reports.
     */
    public static final int UPDATE_TIME = 10;

    /**
     * The current task, or null when no quant task is running.
     */
    private QuantTask currentQuantTask = null;

    /**
     * The commands, keyed by their command name.
     */
    private final Map<String, Cmd> commands = new HashMap<String, Cmd>();

    /**
     * The max iterations, -1 unlimited.
     */
    private int maxIteration = -1;

    /**
     * Holds a copy of the original property data, used to revert. Null until
     * a script has been loaded.
     */
    private Map<String, String> revertData;

    /**
     * The method currently being trained, or null if that method should not
     * be modified, or we are not training.
     */
    private MLMethod method;

    /**
     * Utility object bound to this analyst. It is created during construction,
     * after {@link #script} has been initialized.
     */
    private final AnalystUtility utility = new AnalystUtility(this);

    /**
     * Construct the Encog analyst and register the built-in commands.
     */
    public EncogAnalyst() {
        addCommand(new CmdCreate(this));
        addCommand(new CmdEvaluate(this));
        addCommand(new CmdEvaluateRaw(this));
        addCommand(new CmdGenerate(this));
        addCommand(new CmdNormalize(this));
        addCommand(new CmdRandomize(this));
        addCommand(new CmdSegregate(this));
        addCommand(new CmdTrain(this));
        addCommand(new CmdBalance(this));
        addCommand(new CmdSet(this));
        addCommand(new CmdReset(this));
        addCommand(new CmdCluster(this));
        addCommand(new CmdCode(this));
        addCommand(new CmdProcess(this));
    }

    /**
     * Add a listener.
     * @param listener The listener to add.
     */
    public void addAnalystListener(final AnalystListener listener) {
        this.listeners.add(listener);
    }

    /**
     * Add a command, replacing any command already registered under the
     * same name.
     * @param cmd The command to add.
     */
    public void addCommand(final Cmd cmd) {
        this.commands.put(cmd.getName(), cmd);
    }

    /**
     * Analyze the specified file. Used by the wizard. Records the file as the
     * raw file and stores the header flag in the script properties before
     * analyzing.
     * @param file The file to analyze.
     * @param headers True if headers are present.
     * @param format The format of the file.
     */
    public void analyze(final File file, final boolean headers,
            final AnalystFileFormat format) {
        this.script.getProperties().setFilename(AnalystWizard.FILE_RAW,
                file.toString());
        this.script.getProperties().setProperty(
                ScriptProperties.SETUP_CONFIG_INPUT_HEADERS, headers);
        final PerformAnalysis a = new PerformAnalysis(this.script,
                file.toString(), headers, format);
        a.process(this);
    }

    /**
     * Re-run the analysis of the specified file. Unlike
     * {@link #analyze(File, boolean, AnalystFileFormat)}, this does not
     * update the raw-file or header properties of the script.
     * @param file The file to analyze.
     * @param headers True if headers are present.
     * @param format The format of the file.
     */
    public void reanalyze(final File file, final boolean headers,
            final AnalystFileFormat format) {
        final PerformAnalysis a = new PerformAnalysis(this.script,
                file.toString(), headers, format);
        a.process(this);
    }

    /**
     * Determine the input count. This is the actual number of columns used
     * by all non-ignored input fields.
     * @return The input count.
     */
    public int determineInputCount() {
        int result = 0;
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (field.isInput() && !field.isIgnored()) {
                result += field.getColumnsNeeded();
            }
        }
        return result;
    }

    /**
     * Determine the input field count, the fields are higher-level
     * than columns.
     * @return The number of non-ignored input fields.
     */
    public int determineInputFieldCount() {
        int result = 0;
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (field.isInput() && !field.isIgnored()) {
                result++;
            }
        }
        return result;
    }

    /**
     * Determine the output count, this is the number of output
     * columns needed.
     * @return The output count.
     */
    public int determineOutputCount() {
        int result = 0;
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (field.isOutput() && !field.isIgnored()) {
                result += field.getColumnsNeeded();
            }
        }
        return result;
    }

    /**
     * Determine the number of output fields. Fields are higher
     * level than columns. When the current method is a
     * {@link BayesianNetwork}, one extra field is counted.
     * @return The output field count.
     */
    public int determineOutputFieldCount() {
        int result = 0;
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (field.isOutput() && !field.isIgnored()) {
                result++;
            }
        }
        // NOTE(review): the extra field for Bayesian networks is deliberate in
        // the original code; its exact purpose is not visible here.
        if (this.method instanceof BayesianNetwork) {
            result++;
        }
        return result;
    }

    /**
     * Determine how many unique columns there are. Timeslices are not
     * counted multiple times: only the first non-ignored field with a given
     * name contributes its columns.
     * @return The number of columns.
     */
    public int determineUniqueColumns() {
        final Set<String> used = new HashSet<String>();
        int result = 0;
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (!field.isIgnored() && used.add(field.getName())) {
                result += field.getColumnsNeeded();
            }
        }
        return result;
    }

    /**
     * Determine the unique input field count. Timeslices are not
     * counted multiple times.
     * @return The number of unique input fields.
     */
    public int determineUniqueInputFieldCount() {
        final Set<String> seen = new HashSet<String>();
        int result = 0;
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (field.isInput() && !field.isIgnored()
                    && seen.add(field.getName())) {
                result++;
            }
        }
        return result;
    }

    /**
     * Determine the total input column count, minus ignored fields. This
     * returns the same value as {@link #determineInputCount()}.
     * @return The number of columns used by non-ignored input fields.
     */
    public int determineTotalInputFieldCount() {
        return determineInputCount();
    }

    /**
     * Determine the unique output field count. Do not count timeslices
     * multiple times.
     *
     * <p>A name is only marked as seen once it has been counted as an output.
     * The same field name can appear as an input at one time slice and as an
     * output at another, so marking on any occurrence would miss that
     * output. This matches {@link #determineUniqueInputFieldCount()}.</p>
     * @return The unique output field count.
     */
    public int determineUniqueOutputFieldCount() {
        final Set<String> seen = new HashSet<String>();
        int result = 0;
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (field.isOutput() && !field.isIgnored()
                    && seen.add(field.getName())) {
                result++;
            }
        }
        return result;
    }

    /**
     * Download the raw file from the Internet, unless the resolved raw file
     * already exists.
     */
    public void download() {
        final URL sourceURL = this.script.getProperties().getPropertyURL(
                ScriptProperties.HEADER_DATASOURCE_SOURCE_FILE);
        final String rawFile = this.script.getProperties().getPropertyFile(
                ScriptProperties.HEADER_DATASOURCE_RAW_FILE);
        final File rawFilename = getScript().resolveFilename(rawFile);
        if (!rawFilename.exists()) {
            downloadPage(sourceURL, rawFilename);
        }
    }

    /**
     * Download a file from the specified URL and uncompress it if the URL
     * ends in ".gz". The data is first written to "temp.tmp" in the target's
     * parent directory.
     * @param url The URL.
     * @param file The file to download into.
     * @throws AnalystError If an I/O error occurs or the temp file cannot be
     *             renamed to the target.
     */
    private void downloadPage(final URL url, final File file) {
        final byte[] buffer = new byte[BotUtil.BUFFER_SIZE];
        final File tempFile = new File(file.getParentFile(), "temp.tmp");

        // Step 1: download the URL into the temp file.
        InputStream is = null;
        OutputStream fos = null;
        try {
            fos = new FileOutputStream(tempFile);
            is = url.openStream();
            copyWithProgress(is, fos, buffer, "Downloading... ");
            fos.close();
            fos = null;
        } catch (final IOException e) {
            throw new AnalystError(e);
        } finally {
            closeQuietly(fos);
            closeQuietly(is);
        }

        // Step 2: either uncompress into the target or move the temp file.
        if (url.toString().toLowerCase(Locale.ENGLISH).endsWith(".gz")) {
            FileInputStream fis = null;
            GZIPInputStream gis = null;
            OutputStream out = null;
            try {
                fis = new FileInputStream(tempFile);
                gis = new GZIPInputStream(fis);
                out = new FileOutputStream(file);
                copyWithProgress(gis, out, buffer, "Uncompressing... ");
                out.close();
                out = null;
            } catch (final IOException e) {
                throw new AnalystError(e);
            } finally {
                closeQuietly(out);
                closeQuietly(gis);
                closeQuietly(fis);
            }
            // Delete only after every stream on the temp file is closed;
            // deleting an open file fails on some platforms.
            if (!tempFile.delete()) {
                EncogLogging.log(EncogLogging.LEVEL_INFO,
                        "Could not delete temporary file: " + tempFile);
            }
        } else {
            // The target may legitimately not exist yet, so this result is
            // not checked.
            file.delete();
            if (!tempFile.renameTo(file)) {
                throw new AnalystError("Could not rename " + tempFile
                        + " to " + file);
            }
        }
    }

    /**
     * Copy a stream to another, reporting progress every
     * {@link #UPDATE_TIME} buffer reads.
     * @param in The source stream.
     * @param out The destination stream.
     * @param buffer The transfer buffer.
     * @param label The progress message prefix.
     * @return The number of bytes copied.
     * @throws IOException If reading or writing fails.
     */
    private long copyWithProgress(final InputStream in,
            final OutputStream out, final byte[] buffer, final String label)
            throws IOException {
        long size = 0;
        int lastUpdate = 0;
        int length;
        do {
            length = in.read(buffer);
            if (length >= 0) {
                out.write(buffer, 0, length);
                size += length;
            }
            if (lastUpdate > UPDATE_TIME) {
                report(0, (int) (size / Format.MEMORY_MEG),
                        label + Format.formatMemory(size));
                lastUpdate = 0;
            }
            lastUpdate++;
        } while (length >= 0);
        return size;
    }

    /**
     * Close a stream, logging rather than propagating any close failure.
     * Used on error paths so the original exception is not masked.
     * @param stream The stream to close, may be null.
     */
    private static void closeQuietly(final Closeable stream) {
        if (stream != null) {
            try {
                stream.close();
            } catch (final IOException e) {
                EncogLogging.log(e);
            }
        }
    }

    /**
     * Execute a task, line by line, stopping early if a listener requests a
     * shutdown.
     * @param task The task to execute.
     * @throws AnalystError If a line names an unknown command.
     */
    public void executeTask(final AnalystTask task) {
        final int total = task.getLines().size();
        int current = 1;
        for (final String rawLine : task.getLines()) {
            EncogLogging.log(EncogLogging.LEVEL_DEBUG,
                    "Execute analyst line: " + rawLine);
            reportCommandBegin(total, current, rawLine);

            final String line = rawLine.trim();
            final int index = line.indexOf(' ');
            final String command;
            final String args;
            // Locale.ENGLISH: a default-locale toUpperCase() turns "train"
            // into "TRAİN" under a Turkish locale, breaking the lookup.
            if (index != -1) {
                command = line.substring(0, index).toUpperCase(Locale.ENGLISH);
                args = line.substring(index + 1);
            } else {
                command = line.toUpperCase(Locale.ENGLISH);
                args = "";
            }

            final Cmd cmd = this.commands.get(command);
            if (cmd == null) {
                throw new AnalystError("Unknown Command: " + line);
            }
            final boolean canceled = cmd.executeCommand(args);

            reportCommandEnd(canceled);
            setCurrentQuantTask(null);
            current++;
            if (shouldStopAll()) {
                break;
            }
        }
    }

    /**
     * Execute a task.
     * @param name The name of the task to execute.
     * @throws AnalystError If no task with that name exists.
     */
    public void executeTask(final String name) {
        EncogLogging.log(EncogLogging.LEVEL_INFO,
                "Analyst execute task:" + name);
        final AnalystTask task = this.script.getTask(name);
        if (task == null) {
            throw new AnalystError("Can't find task: " + name);
        }
        executeTask(task);
    }

    /**
     * @return The lag depth: the largest absolute value of any negative time
     *         slice, or 0 if there is none.
     */
    public int getLagDepth() {
        int result = 0;
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (field.getTimeSlice() < 0) {
                result = Math.max(result, Math.abs(field.getTimeSlice()));
            }
        }
        return result;
    }

    /**
     * @return The lead depth: the largest positive time slice, or 0 if there
     *         is none.
     */
    public int getLeadDepth() {
        int result = 0;
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (field.getTimeSlice() > 0) {
                result = Math.max(result, field.getTimeSlice());
            }
        }
        return result;
    }

    /**
     * @return The live, mutable list of listeners.
     */
    public List<AnalystListener> getListeners() {
        return this.listeners;
    }

    /**
     * @return The max iterations, -1 for unlimited.
     */
    public int getMaxIteration() {
        return this.maxIteration;
    }

    /**
     * @return The property snapshot taken by the last load, used to revert;
     *         null if nothing has been loaded.
     */
    public Map<String, String> getRevertData() {
        return this.revertData;
    }

    /**
     * @return the script
     */
    public AnalystScript getScript() {
        return this.script;
    }

    /**
     * Load the specified script file. The file's parent directory becomes
     * the script's base path.
     * @param file The file to load.
     * @throws AnalystError If the file cannot be read.
     */
    public void load(final File file) {
        InputStream fis = null;
        this.script.setBasePath(file.getParent());
        try {
            fis = new FileInputStream(file);
            load(fis);
            fis.close();
            fis = null;
        } catch (final IOException ex) {
            throw new AnalystError(ex);
        } finally {
            // Non-null only if something above failed: close without masking.
            closeQuietly(fis);
        }
    }

    /**
     * Load from an input stream and take a revert snapshot of the
     * properties.
     * @param stream The stream to load from.
     */
    public void load(final InputStream stream) {
        this.script.load(stream);
        this.revertData = this.script.getProperties().prepareRevert();
    }

    /**
     * Load from the specified filename.
     * @param filename The filename to load from.
     */
    public void load(final String filename) {
        load(new File(filename));
    }

    /**
     * Remove a listener.
     * @param listener The listener to remove.
     */
    public void removeAnalystListener(final AnalystListener listener) {
        this.listeners.remove(listener);
    }

    /**
     * Report progress.
     * @param total The total units.
     * @param current The current unit.
     * @param message The message.
     */
    private void report(final int total, final int current,
            final String message) {
        for (final AnalystListener listener : this.listeners) {
            listener.report(total, current, message);
        }
    }

    /**
     * Report that a command has begun.
     * @param total The total units.
     * @param current The current unit.
     * @param name The command name.
     */
    private void reportCommandBegin(final int total, final int current,
            final String name) {
        for (final AnalystListener listener : this.listeners) {
            listener.reportCommandBegin(total, current, name);
        }
    }

    /**
     * Report that a command has ended.
     * @param canceled Was the command canceled.
     */
    private void reportCommandEnd(final boolean canceled) {
        for (final AnalystListener listener : this.listeners) {
            listener.reportCommandEnd(canceled);
        }
    }

    /**
     * Report training.
     * @param train The trainer.
     */
    public void reportTraining(final MLTrain train) {
        for (final AnalystListener listener : this.listeners) {
            listener.reportTraining(train);
        }
    }

    /**
     * Report that training has begun.
     */
    public void reportTrainingBegin() {
        for (final AnalystListener listener : this.listeners) {
            listener.reportTrainingBegin();
        }
    }

    /**
     * Report that training has ended.
     */
    public void reportTrainingEnd() {
        for (final AnalystListener listener : this.listeners) {
            listener.reportTrainingEnd();
        }
    }

    /**
     * Save the script to a file. The file's parent directory becomes the
     * script's base path.
     * @param file The file to save to.
     * @throws AnalystError If the file cannot be written or closed.
     */
    public void save(final File file) {
        OutputStream fos = null;
        try {
            this.script.setBasePath(file.getParent());
            fos = new FileOutputStream(file);
            save(fos);
            // Close on the success path so a failed flush is reported.
            fos.close();
            fos = null;
        } catch (final IOException ex) {
            throw new AnalystError(ex);
        } finally {
            // Non-null only if something above failed: close without masking.
            closeQuietly(fos);
        }
    }

    /**
     * Save the script to a stream.
     * @param stream The stream to save to.
     */
    public void save(final OutputStream stream) {
        this.script.save(stream);
    }

    /**
     * Save the script to a filename.
     * @param filename The filename to save to.
     */
    public void save(final String filename) {
        save(new File(filename));
    }

    /**
     * Set the current task.
     * @param task The current task, or null when none is running.
     */
    public synchronized void setCurrentQuantTask(final QuantTask task) {
        this.currentQuantTask = task;
    }

    /**
     * Set the max iterations.
     * @param i The value for max iterations, -1 for unlimited.
     */
    public void setMaxIteration(final int i) {
        this.maxIteration = i;
    }

    /**
     * Should all commands be stopped.
     * @return True, if any listener requests a shutdown.
     */
    private boolean shouldStopAll() {
        for (final AnalystListener listener : this.listeners) {
            if (listener.shouldShutDown()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Should the current command be stopped.
     * @return True if any listener requests the current command to stop.
     */
    public boolean shouldStopCommand() {
        for (final AnalystListener listener : this.listeners) {
            if (listener.shouldStopCommand()) {
                return true;
            }
        }
        return false;
    }

    /**
     * Request that the current quant task stop, if one is running.
     */
    public synchronized void stopCurrentTask() {
        if (this.currentQuantTask != null) {
            this.currentQuantTask.requestStop();
        }
    }

    /**
     * @return True, if any field has a non-zero time slice.
     */
    public boolean isTimeSeries() {
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (field.getTimeSlice() != 0) {
                return true;
            }
        }
        return false;
    }

    /**
     * @return the method currently being trained, or null
     */
    public MLMethod getMethod() {
        return this.method;
    }

    /**
     * @param method the method to set
     */
    public void setMethod(final MLMethod method) {
        this.method = method;
    }

    /**
     * Determine the total number of columns used by all non-ignored fields,
     * counting every time slice separately.
     * @return The total column count.
     */
    public int determineTotalColumns() {
        int result = 0;
        for (final AnalystField field : this.script.getNormalize()
                .getNormalizedFields()) {
            if (!field.isIgnored()) {
                result += field.getColumnsNeeded();
            }
        }
        return result;
    }

    /**
     * Determine the largest time slice over all normalized fields.
     * @return The maximum time slice, or {@link Integer#MIN_VALUE} if there
     *         are no normalized fields.
     */
    public int determineMaxTimeSlice() {
        int result = Integer.MIN_VALUE;
        for (final AnalystField field : this.getScript().getNormalize()
                .getNormalizedFields()) {
            result = Math.max(result, field.getTimeSlice());
        }
        return result;
    }

    /**
     * Determine the smallest time slice over all normalized fields.
     * @return The minimum time slice, or {@link Integer#MAX_VALUE} if there
     *         are no normalized fields.
     */
    public int determineMinTimeSlice() {
        int result = Integer.MAX_VALUE;
        for (final AnalystField field : this.getScript().getNormalize()
                .getNormalizedFields()) {
            result = Math.min(result, field.getTimeSlice());
        }
        return result;
    }

    /**
     * @return the utility
     */
    public AnalystUtility getUtility() {
        return this.utility;
    }
}