Package org.drools.scorecards

Source Code of org.drools.scorecards.ScoringStrategiesTest

package org.drools.scorecards;

import org.dmg.pmml.pmml_4_1.descr.Extension;
import org.dmg.pmml.pmml_4_1.descr.PMML;
import org.dmg.pmml.pmml_4_1.descr.Scorecard;
import org.drools.pmml.pmml_4_1.extensions.AggregationStrategy;
import org.drools.scorecards.pmml.ScorecardPMMLExtensionNames;
import org.drools.scorecards.pmml.ScorecardPMMLUtils;
import org.junit.Before;
import org.junit.Test;
import org.kie.api.KieBase;
import org.kie.api.KieServices;
import org.kie.api.builder.KieBuilder;
import org.kie.api.builder.KieFileSystem;
import org.kie.api.builder.Message;
import org.kie.api.builder.Results;
import org.kie.api.definition.type.FactType;
import org.kie.api.io.ResourceType;
import org.kie.api.runtime.KieContainer;
import org.kie.api.runtime.StatelessKieSession;

import java.io.InputStream;

import static org.junit.Assert.*;
import static org.drools.scorecards.ScorecardCompiler.DrlType.INTERNAL_DECLARED_TYPES;

public class ScoringStrategiesTest {


    @Before
    public void setUp() throws Exception {
    }

    @Test
    public void testScoringExtension() throws Exception {
        PMML pmmlDocument;
        ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
        if (scorecardCompiler.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_scoring_strategies.xls")) ) {
            pmmlDocument = scorecardCompiler.getPMMLDocument();
            assertNotNull(pmmlDocument);
            String drl = scorecardCompiler.getDRL();
            assertNotNull(drl);
            for (Object serializable : pmmlDocument.getAssociationModelsAndBaselineModelsAndClusteringModels()){
                if (serializable instanceof Scorecard){
                    Scorecard scorecard = (Scorecard)serializable;
                    assertEquals("Sample Score",scorecard.getModelName());
                    Extension extension = ScorecardPMMLUtils.getExtension(scorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), ScorecardPMMLExtensionNames.SCORECARD_SCORING_STRATEGY);
                    assertNotNull(extension);
                    assertEquals( extension.getValue(), AggregationStrategy.AGGREGATE_SCORE.toString() );
                    return;
                }
            }
        }
        fail();
    }

    @Test
    public void testAggregate() throws Exception {

        double finalScore = executeAndFetchScore("scorecards");
        //age==10 (30), validLicense==FALSE (-1)
        assertEquals(29.0, finalScore, 0.0);
    }

    @Test
    public void testAverage() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_avg");
        //age==10 (30), validLicense==FALSE (-1)
        //count = 2
        assertEquals(14.5, finalScore, 0.0);
    }

    @Test
    public void testMinimum() throws Exception {
        double finalScore = executeAndFetchScore("scorecards_min");
        //age==10 (30), validLicense==FALSE (-1)
        assertEquals(-1.0, finalScore, 0.0);
    }

    @Test
    public void testMaximum() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_max");
        //age==10 (30), validLicense==FALSE (-1)
        assertEquals(30.0, finalScore, 0.0);
    }

    @Test
    public void testWeightedAggregate() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_w_aggregate");
        //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1)
        assertEquals(599.0, finalScore, 0.0);
    }

    @Test
    public void testWeightedAverage() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_w_avg");
        //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1)
        assertEquals(299.5, finalScore, 0.0);
    }

    @Test
    public void testWeightedMaximum() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_w_max");
        //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1)
        assertEquals(600.0, finalScore, 0.0);
    }

    @Test
    public void testWeightedMinimum() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_w_min");
        //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1)
        assertEquals(-1.0, finalScore, 0.0);
    }

    /* Tests with Initial Score */
    @Test
    public void testAggregateInitialScore() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_initial_score");
        //age==10 (30), validLicense==FALSE (-1)
        //initialScore = 100
        assertEquals(129.0, finalScore, 0.0);
    }

    @Test
    public void testAverageInitialScore() throws Exception {
        double finalScore = executeAndFetchScore("scorecards_avg_initial_score");
        //age==10 (30), validLicense==FALSE (-1)
        //count = 2
        //initialScore = 100
        assertEquals(114.5, finalScore, 0.0);
    }

    @Test
    public void testMinimumInitialScore() throws Exception {
        double finalScore = executeAndFetchScore("scorecards_min_initial_score");
        //age==10 (30), validLicense==FALSE (-1)
        //initialScore = 100
        assertEquals(99.0, finalScore, 0.0);
    }

    @Test
    public void testMaximumInitialScore() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_max_initial_score");
        //age==10 (30), validLicense==FALSE (-1)
        //initialScore = 100
        assertEquals(130.0, finalScore, 0.0);
    }

    @Test
    public void testWeightedAggregateInitialScore() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_w_aggregate_initial");
        //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1)
        //initialScore = 100
        assertEquals(699.0, finalScore, 0.0);
    }

    @Test
    public void testWeightedAverageInitialScore() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_w_avg_initial");
        //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1)
        //initialScore = 100
        assertEquals(399.5, finalScore, 0.0);
    }

    @Test
    public void testWeightedMaximumInitialScore() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_w_max_initial");
        //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1)
        //initialScore = 100
        assertEquals(700.0, finalScore, 0.0);
    }

    @Test
    public void testWeightedMinimumInitialScore() throws Exception {

        double finalScore = executeAndFetchScore("scorecards_w_min_initial");
        //age==10 (score=30, w=20), validLicense==FALSE (score=-1, w=1)
        //initialScore = 100
        assertEquals(99.0, finalScore, 0.0);
    }

    /* Internal functions */
    private double executeAndFetchScore(String sheetName) throws Exception {

        ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
        InputStream inputStream = PMMLDocumentTest.class.getResourceAsStream( "/scoremodel_scoring_strategies.xls" );
        boolean compileResult = scorecardCompiler.compileFromExcel(inputStream, sheetName);
        if (!compileResult) {
            for(ScorecardError error : scorecardCompiler.getScorecardParseErrors()){
                System.err.println("Scorecard Compiler Error :"+error.getErrorLocation()+"->"+error.getErrorMessage());
            }
            return -999999;
        }
        String drl = scorecardCompiler.getDRL();

        KieServices ks = KieServices.Factory.get();
        KieFileSystem kfs = ks.newKieFileSystem();
        kfs.write( ks.getResources().newByteArrayResource( drl.getBytes() )
                           .setSourcePath( "scoremodel_scoring_strategies.drl" )
                           .setResourceType( ResourceType.DRL ) );
        KieBuilder kieBuilder = ks.newKieBuilder( kfs );
        Results res = kieBuilder.buildAll().getResults();
        if ( res.hasMessages( Message.Level.ERROR ) ) {
            System.out.println( res.getMessages() );
        }
        assertEquals( 0, res.getMessages( Message.Level.ERROR ).size() );

        KieContainer kieContainer = ks.newKieContainer( kieBuilder.getKieModule().getReleaseId() );

        KieBase kbase = kieContainer.getKieBase();
        StatelessKieSession session = kbase.newStatelessKieSession();

        FactType scorecardType = kbase.getFactType( "org.drools.scorecards.example","SampleScore" );
        Object scorecard = scorecardType.newInstance();
        scorecardType.set(scorecard, "age", 10);
        session.execute(scorecard);
        return (Double) scorecardType.get( scorecard, "scorecard__calculatedScore" );
    }

}
TOP

Related Classes of org.drools.scorecards.ScoringStrategiesTest

TOP
Copyright © 2018 www.massapi.com. All rights reserved.
All source code is the property of its respective owners. Java is a trademark of Sun Microsystems, Inc. and is owned by ORACLE Inc. Contact coftware#gmail.com.