Package com.heatonresearch.aifh.general.data

Examples of com.heatonresearch.aifh.general.data.BasicData


        // Divide over the k sets.
        int leaveOutSet = 0;

        while (temp.size() > 0) {
            int idx = rnd.nextInt(temp.size());
            BasicData item = temp.get(idx);
            temp.remove(idx);

            this.folds.get(leaveOutSet).getValidationSet().add(item);
            for (int includeSet = 0; includeSet < this.folds.size(); includeSet++) {
                if (includeSet != leaveOutSet) {
View Full Code Here


        }

        String[] nextLine;

        while ((nextLine = reader.readNext()) != null) {
            BasicData data = new BasicData(TitanicConfig.InputFeatureCount, 1);

            String name = nextLine[nameIndex];
            String sex = nextLine[sexIndex];
            String embarked = nextLine[indexEmbarked];
            String id = nextLine[indexId];

            // Record the passenger id, if requested
            if (ids != null) {
                ids.add(id);
            }

            boolean isMale = sex.equalsIgnoreCase("male");


            // age
            double age;

            // do we have an age for this person?
            if (nextLine[ageIndex].length() == 0) {
                // age is missing, interpolate using name
                if (name.contains("Master.")) {
                    age = stats.getMeanMaster().calculate();
                } else if (name.contains("Mr.")) {
                    age = stats.getMeanMr().calculate();
                } else if (name.contains("Miss.") || name.contains("Mlle.")) {
                    age = stats.getMeanMiss().calculate();
                } else if (name.contains("Mrs.") || name.contains("Mme.")) {
                    age = stats.getMeanMrs().calculate();
                } else if (name.contains("Col.") || name.contains("Capt.") || name.contains("Major.")) {
                    age = stats.getMeanMiss().calculate();
                } else if (name.contains("Countess.") || name.contains("Lady.") || name.contains("Sir.") || name.contains("Don.") || name.contains("Dona.") || name.contains("Jonkheer.")) {
                    age = stats.getMeanNobility().calculate();
                } else if (name.contains("Dr.")) {
                    age = stats.getMeanDr().calculate();
                } else if (name.contains("Rev.")) {
                    age = stats.getMeanClergy().calculate();
                } else {
                    if (isMale) {
                        age = stats.getMeanMale().calculate();
                    } else {
                        age = stats.getMeanFemale().calculate();
                    }
                }
            } else {
                age = Double.parseDouble(nextLine[ageIndex]);

            }
            data.getInput()[0] = rangeNormalize(age, 0, 100, inputLow, inputHigh);

            // sex-male
            data.getInput()[1] = isMale ? inputHigh : inputLow;

            // pclass
            double pclass = Double.parseDouble(nextLine[indexPclass]);
            data.getInput()[2] = rangeNormalize(pclass, 1, 3, inputLow, inputHigh);

            // sibsp
            double sibsp = Double.parseDouble(nextLine[indexSibsp]);
            data.getInput()[3] = rangeNormalize(sibsp, 0, 10, inputLow, inputHigh);

            // parch
            double parch = Double.parseDouble(nextLine[indexParch]);
            data.getInput()[4] = rangeNormalize(parch, 0, 10, inputLow, inputHigh);

            // fare
            String strFare = nextLine[indexFare];
            double fare;

            if (strFare.length() == 0) {
                if (((int) pclass) == 1) {
                    fare = stats.getMeanFare1().calculate();
                } else if (((int) pclass) == 2) {
                    fare = stats.getMeanFare2().calculate();
                } else if (((int) pclass) == 3) {
                    fare = stats.getMeanFare3().calculate();
                } else {
                    // should not happen, we would have a class other than 1,2,3.
                    // however, if that DID happen, use the median class (2).
                    fare = stats.getMeanFare2().calculate();
                }
            } else {
                fare = Double.parseDouble(nextLine[indexFare]);
            }
            data.getInput()[5] = rangeNormalize(fare, 0, 500, inputLow, inputHigh);

            // embarked-c
            data.getInput()[6] = embarked.trim().equalsIgnoreCase("c") ? inputHigh : inputLow;

            // embarked-q
            data.getInput()[7] = embarked.trim().equalsIgnoreCase("q") ? inputHigh : inputLow;

            // embarked-s
            data.getInput()[8] = embarked.trim().equalsIgnoreCase("s") ? inputHigh : inputLow;

            // name-mil
            data.getInput()[9] = (name.contains("Col.") || name.contains("Capt.") || name.contains("Major.")) ? inputHigh : inputLow;

            // name-nobility
            data.getInput()[10] = (name.contains("Countess.") || name.contains("Lady.") || name.contains("Sir.") || name.contains("Don.") || name.contains("Dona.") || name.contains("Jonkheer.")) ? inputHigh : inputLow;

            // name-dr
            data.getInput()[11] = (name.contains("Dr.")) ? inputHigh : inputLow;


            // name-clergy
            data.getInput()[12] = (name.contains("Rev.")) ? inputHigh : inputLow;

            // add the new row
            result.add(data);

            // add survived, if it exists
            if (survivedIndex != -1) {
                int survived = Integer.parseInt(nextLine[survivedIndex]);
                data.getIdeal()[0] = (survived == 1) ? predictSurvive : predictPerish;
            }

        }

        return result;
View Full Code Here

        final Matrix xMatrix = new Matrix(rowCount, inputColCount + 1);
        final Matrix yMatrix = new Matrix(rowCount, 1);

        for (int row = 0; row < trainingData.size(); row++) {
            final BasicData dataRow = this.trainingData.get(row);
            final int colSize = dataRow.getInput().length;

            xMatrix.set(row, 0, 1);
            for (int col = 0; col < colSize; col++) {
                xMatrix.set(row, col + 1, dataRow.getInput()[col]);
            }
            yMatrix.set(row, 0, dataRow.getIdeal()[0]);
        }

        // Calculate the least squares solution
        final QRDecomposition qr = new QRDecomposition(xMatrix);
        final Matrix beta = qr.solve(yMatrix);

        double sum = 0.0;
        for (int i = 0; i < inputColCount; i++)
            sum += yMatrix.get(i, 0);
        final double mean = sum / inputColCount;

        for (int i = 0; i < inputColCount; i++) {
            final double dev = yMatrix.get(i, 0) - mean;
            sst += dev * dev;
        }

        final Matrix residuals = xMatrix.times(beta).minus(yMatrix);
        sse = residuals.norm2() * residuals.norm2();

        for (int i = 0; i < this.algorithm.getLongTermMemory().length; i++) {
            this.algorithm.getLongTermMemory()[i] = beta.get(i, 0);
        }

        // calculate error
        this.errorCalculation.clear();
        for (final BasicData dataRow : this.trainingData) {
            final double[] output = this.algorithm.computeRegression(dataRow.getInput());
            this.errorCalculation.updateError(output, dataRow.getIdeal(), 1.0);
        }
        this.error = this.errorCalculation.calculate();
    }
View Full Code Here

        final double[] errors = new double[rowCount];
        final double[] weights = new double[rowCount];
        final Matrix deltas;

        for (int i = 0; i < rowCount; i++) {
            final BasicData element = this.trainingData.get(i);

            working[i][0] = 1;
            for (int j = 0; j < element.getInput().length; j++)
                working[i][j + 1] = element.getInput()[j];
        }

        for (int i = 0; i < rowCount; i++) {
            final BasicData element = this.trainingData.get(i);
            final double y = this.algorithm.computeRegression(element.getInput())[0];
            errors[i] = y - element.getIdeal()[0];
            weights[i] = y * (1.0 - y);
        }

        for (int i = 0; i < gradient.getColumnDimension(); i++) {
            gradient.set(0, i, 0);
View Full Code Here

TOP

Related Classes of com.heatonresearch.aifh.general.data.BasicData

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.