Package: org.encog.ml.train

Usage examples of org.encog.ml.train.MLTrain


  public static double trainNetwork(final String what,
      final BasicNetwork network, final MLDataSet trainingSet) {
    // train the neural network
    CalculateScore score = new TrainingSetScore(trainingSet);
    final MLTrain trainAlt = new NeuralSimulatedAnnealing(
        network, score, 10, 2, 100);

    final MLTrain trainMain = new Backpropagation(network, trainingSet,0.000001, 0.0);

    ((Propagation)trainMain).setNumThreads(1);
    final StopTrainingStrategy stop = new StopTrainingStrategy();
    trainMain.addStrategy(new Greedy());
    trainMain.addStrategy(new HybridStrategy(trainAlt));
    trainMain.addStrategy(stop);

    int epoch = 0;
    while (!stop.shouldStop()) {
      trainMain.iteration();
      System.out.println("Training " + what + ", Epoch #" + epoch
          + " Error:" + trainMain.getError());
      epoch++;
    }
    return trainMain.getError();
  }
View Full Code Here


    final Stopwatch watch = new Stopwatch();
    try {
      watch.start();

      this.currentJob.createTrainer(this.manager.isSingleThreaded());
      final MLTrain train = this.currentJob.getTrain();
      int interation = 1;

      while (this.currentJob.shouldContinue()) {
        train.iteration();
        interation++;
      }
      watch.stop();
    } catch (final Throwable t) {
      this.currentJob.setError(t);
View Full Code Here

    return network;
  }
 
  public void train(BasicNetwork network,MLDataSet training)
  {
    final MLTrain train = new ResilientPropagation(network, training);

    int epoch = 1;

    do {
      train.iteration();
      System.out
          .println("Epoch #" + epoch + " Error:" + train.getError());
      epoch++;
    } while(train.getError() > MAX_ERROR);
  }
View Full Code Here

   *            The error level to train to.
   */
  public static void trainToError(final MLMethod method,
      final MLDataSet dataSet, final double error) {

    MLTrain train;

    if (method instanceof SVM) {
      train = new SVMTrain((SVM)method, dataSet);
    } else {
      train = new ResilientPropagation((ContainsFlat)method, dataSet);
View Full Code Here

        XORSQL.SQL_URL,
        XORSQL.SQL_UID,
        XORSQL.SQL_PWD);
   
    // train the neural network
    final MLTrain train = new ResilientPropagation(network, trainingSet);
    // reset if improve is less than 1% over 5 cycles
    train.addStrategy(new RequiredImprovementStrategy(5));
   
    int epoch = 1;

    do {
      train.iteration();
      System.out
          .println("Epoch #" + epoch + " Error:" + train.getError());
      epoch++;
    } while(train.getError() > 0.01);

    // test the neural network
    System.out.println("Neural Network Results:");
    for(MLDataPair pair: trainingSet ) {
      final MLData output = network.compute(pair.getInput());
View Full Code Here

        //Smooth training data provides true values for the provided input dimensions.
        create2DSmoothTainingDataGit();

        //Create the training set and train.
        MLDataSet trainingSet = new BasicMLDataSet(INPUT, IDEAL);
        MLTrain train = new SVDTraining(network, trainingSet);

        //SVD is a single step solve
        int epoch = 1;
        do
        {
            train.iteration();
            System.out.println("Epoch #" + epoch + " Error:" + train.getError());
            epoch++;
        } while ((epoch < 1) && (train.getError() > 0.001));

        // test the neural network
        System.out.println("Neural Network Results:");

        //Create a testing array which may be to a higher resoltion than the original training data
View Full Code Here

    pattern.setOutputNeurons(outputNeurons);
    BasicNetwork network = (BasicNetwork)pattern.generate();
   
    // train it
    MLDataSet training = generateTraining();
    MLTrain train = new TrainAdaline(network,training,0.01);
   
    int epoch = 1;
    do {
      train.iteration();
      System.out
          .println("Epoch #" + epoch + " Error:" + train.getError());
      epoch++;
    } while(train.getError() > 0.01);
   
    //
    System.out.println("Error:" + network.calculateError(training));
   
    // test it
View Full Code Here

    network.reset();

    MLDataSet trainingSet = new BasicMLDataSet(input, output);

    // train the neural network
    MLTrain train = new Backpropagation(network, trainingSet, 0.7, 0.7);

    Stopwatch sw = new Stopwatch();
    sw.start();
    // run epoch of learning procedure
    for (int i = 0; i < ITERATIONS; i++) {
      train.iteration();
    }
    sw.stop();

    return sw.getElapsedMilliseconds();
  }
View Full Code Here

    return network;
  }

  public void train(BasicNetwork network, MLDataSet training) {
    final FoldedDataSet folded = new FoldedDataSet(training);
    final MLTrain train = new ResilientPropagation(network, folded);
    final CrossValidationKFold trainFolded = new CrossValidationKFold(train,4);

    int epoch = 1;

    do {
View Full Code Here

    // second, create the data set   
    MLDataSet dataSet = new BasicMLDataSet(XOR_INPUT, XOR_IDEAL);
   
    // third, create the trainer
    MLTrainFactory trainFactory = new MLTrainFactory()
    MLTrain train = trainFactory.create(method,dataSet,trainerName,trainerArgs);       
    // reset if improve is less than 1% over 5 cycles
    if( method instanceof MLResettable && !(train instanceof ManhattanPropagation) ) {
      train.addStrategy(new RequiredImprovementStrategy(500));
    }

    // fourth, train and evaluate.
    EncogUtility.trainToError(train, 0.01);
    EncogUtility.evaluate((MLRegression)method, dataSet);
View Full Code Here

TOP

Related Classes of org.encog.ml.train.MLTrain

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and owned by Oracle Inc. Contact coftware#gmail.com.