Package hivemall.io

Examples of hivemall.io.Margin


            return cl;
        }

        @Override
        protected void train(List<?> features, Object actual_label) {
            Margin margin = getMarginAndVariance(features, actual_label);

            float loss = loss(margin);
            if(loss > 0.f) {
                float var = margin.getVariance();
                float beta = 1.f / (var + r);
                float alpha = loss * beta;

                Object missed_label = margin.getMaxIncorrectLabel();
                update(features, actual_label, missed_label, alpha, beta);
            }
        }
View Full Code Here


        return cl;
    }

    @Override
    protected void train(List<?> features, Object actual_label) {
        Margin margin = getMarginAndVariance(features, actual_label, true);
        float gamma = getGamma(margin);

        if(gamma > 0.f) {// alpha = max(0, gamma)                  
            Object missed_label = margin.getMaxIncorrectLabel();
            update(features, gamma, actual_label, missed_label);
        }
    }
View Full Code Here

        return cl;
    }

    @Override
    protected void train(List<?> features, Object actual_label) {
        Margin margin = getMarginAndVariance(features, actual_label, true);
        float loss = loss(margin);

        if(loss > 0.f) {
            float alpha = getAlpha(margin);
            if(alpha == 0.f) {
                return;
            }
            float beta = getBeta(margin, alpha);
            if(beta == 0.f) {
                return;
            }

            Object missed_label = margin.getMaxIncorrectLabel();
            update(features, actual_label, missed_label, alpha, beta);
        }
    }
View Full Code Here

                    maxAnotherLabel = label;
                    maxAnotherScore = score;
                }
            }
        }
        return new Margin(correctScore, maxAnotherLabel, maxAnotherScore);
    }
View Full Code Here

        float maxAnotherScore = 0.f;
        float maxAnotherVariance = 0.f;

        if(nonZeroVariance && label2model.isEmpty()) {// for initial call
            float var = 2.f * calcVariance(features);
            return new Margin(correctScore, maxAnotherLabel, maxAnotherScore).variance(var);
        }

        for(Map.Entry<Object, PredictionModel> label2map : label2model.entrySet()) {// for each class
            Object label = label2map.getKey();
            PredictionModel model = label2map.getValue();
            PredictionResult predicted = calcScoreAndVariance(model, features);
            float score = predicted.getScore();

            if(label.equals(actual_label)) {
                correctScore = score;
                correctVariance = predicted.getVariance();
            } else {
                if(maxAnotherLabel == null || score > maxAnotherScore) {
                    maxAnotherLabel = label;
                    maxAnotherScore = score;
                    maxAnotherVariance = predicted.getVariance();
                }
            }
        }

        float var = correctVariance + maxAnotherVariance;
        return new Margin(correctScore, maxAnotherLabel, maxAnotherScore).variance(var);
    }
View Full Code Here

    @Override
    protected void train(List<?> features, Object actual_label) {
        assert (!features.isEmpty());

        Margin margin = getMargin(features, actual_label);
        float loss = loss(margin);

        if(loss > 0.f) { // & missed_label != null
            float sqnorm = squaredNorm(features);
            if(sqnorm == 0.f) {// avoid divide by zero
                return;
            }
            float coeff = eta(loss, sqnorm);
            Object missed_label = margin.getMaxIncorrectLabel();
            update(features, coeff, actual_label, missed_label);
        }
    }
View Full Code Here

        return cl;
    }

    @Override
    protected void train(List<?> features, Object actual_label) {
        Margin margin = getMarginAndVariance(features, actual_label);
        float m = margin.get();

        if(m >= 1.f) {
            return;
        }

        float var = margin.getVariance();
        float beta = 1.f / (var + r);
        float alpha = (1.f - m) * beta;

        Object missed_label = margin.getMaxIncorrectLabel();
        update(features, actual_label, missed_label, alpha, beta);
    }
View Full Code Here

TOP

Related Classes of hivemall.io.Margin

Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by Oracle Inc. Contact coftware#gmail.com.