Search in sources:

Example 6 with NumericPredictor

Use of org.dmg.pmml.regression.NumericPredictor in the drools project by kiegroup.

Class PMMLModelTestUtils, method getRegressionTable.

/**
 * Builds a {@link RegressionTable} test fixture from the supplied predictors.
 *
 * @param categoricalPredictors categorical predictors to add, in list order
 * @param numericPredictors     numeric predictors to add, in list order
 * @param predictorTerms        predictor terms to add, in list order
 * @param intercept             intercept value for the table
 * @param targetCategory        target category for the table
 * @return a new table populated with every given value
 */
public static RegressionTable getRegressionTable(List<CategoricalPredictor> categoricalPredictors, List<NumericPredictor> numericPredictors, List<PredictorTerm> predictorTerms, double intercept, Object targetCategory) {
    final RegressionTable table = new RegressionTable();
    table.setIntercept(intercept);
    table.setTargetCategory(targetCategory);

    // The JPMML "add" methods take varargs, so each list is converted to a typed array first.
    final CategoricalPredictor[] categoricalArray = categoricalPredictors.toArray(new CategoricalPredictor[0]);
    table.addCategoricalPredictors(categoricalArray);

    final NumericPredictor[] numericArray = numericPredictors.toArray(new NumericPredictor[0]);
    table.addNumericPredictors(numericArray);

    final PredictorTerm[] termArray = predictorTerms.toArray(new PredictorTerm[0]);
    table.addPredictorTerms(termArray);

    return table;
}
Also used : CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) PredictorTerm(org.dmg.pmml.regression.PredictorTerm) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable)

Example 7 with NumericPredictor

Use of org.dmg.pmml.regression.NumericPredictor in the drools project by kiegroup.

Class KiePMMLRegressionModelFactoryTest, method setup.

/**
 * Builds the shared fixtures used by the tests in this class.
 * <p>
 * Creates three regression tables (target categories {@code tableTargetCategory + "-0"} to
 * {@code "-2"}). Each table holds three categorical predictors ("CatPred-0".."CatPred-2"),
 * three numeric predictors ("NumPred-0".."NumPred-2") and three predictor terms
 * ("PredTerm-0".."PredTerm-2"). Each term references the categorical and numeric predictor
 * with the same index. Coefficients, exponents and intercept offsets come from an unseeded
 * {@link Random}, so they differ between runs.
 * <p>
 * Every distinct predictor field name also gets a DataField and an ACTIVE MiningField. One of
 * those mining fields is then switched to TARGET. The dictionaries, mining schema and
 * regression model are then wrapped in a {@link PMML} instance, and the Java template
 * compilation unit is loaded.
 */
@BeforeClass
public static void setup() {
    // NOTE(review): unseeded, so generated coefficients/exponents are not reproducible across runs.
    Random random = new Random();
    // Field names depend only on j, so this set ends up with six distinct names
    // (CatPred-0..2, NumPred-0..2) even though three tables are built.
    Set<String> fieldNames = new HashSet<>();
    regressionTables = IntStream.range(0, 3).mapToObj(i -> {
        List<CategoricalPredictor> categoricalPredictors = new ArrayList<>();
        List<NumericPredictor> numericPredictors = new ArrayList<>();
        List<PredictorTerm> predictorTerms = new ArrayList<>();
        IntStream.range(0, 3).forEach(j -> {
            String catFieldName = "CatPred-" + j;
            String numFieldName = "NumPred-" + j;
            categoricalPredictors.add(getCategoricalPredictor(catFieldName, random.nextDouble(), random.nextDouble()));
            numericPredictors.add(getNumericPredictor(numFieldName, random.nextInt(), random.nextDouble()));
            // Each term combines the categorical and numeric predictor of the same index.
            predictorTerms.add(getPredictorTerm("PredTerm-" + j, random.nextDouble(), Arrays.asList(catFieldName, numFieldName)));
            fieldNames.add(catFieldName);
            fieldNames.add(numFieldName);
        });
        return getRegressionTable(categoricalPredictors, numericPredictors, predictorTerms, tableIntercept + random.nextDouble(), tableTargetCategory + "-" + i);
    }).collect(Collectors.toList());
    dataFields = new ArrayList<>();
    miningFields = new ArrayList<>();
    fieldNames.forEach(fieldName -> {
        // NOTE(review): every field, including the NumPred-* ones, is declared CATEGORICAL/STRING — confirm intended.
        dataFields.add(getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING));
        miningFields.add(getMiningField(fieldName, MiningField.UsageType.ACTIVE));
    });
    // fieldNames is a HashSet, so which field becomes the target depends on its iteration order.
    targetMiningField = miningFields.get(0);
    targetMiningField.setUsageType(MiningField.UsageType.TARGET);
    dataDictionary = getDataDictionary(dataFields);
    transformationDictionary = new TransformationDictionary();
    miningSchema = getMiningSchema(miningFields);
    regressionModel = getRegressionModel(modelName, MiningFunction.REGRESSION, miningSchema, regressionTables);
    COMPILATION_UNIT = getFromFileName(KIE_PMML_REGRESSION_MODEL_TEMPLATE_JAVA);
    // Unchecked .get(): presumably fails if the template does not declare the expected class.
    MODEL_TEMPLATE = COMPILATION_UNIT.getClassByName(KIE_PMML_REGRESSION_MODEL_TEMPLATE).get();
    pmml = new PMML();
    pmml.setDataDictionary(dataDictionary);
    pmml.setTransformationDictionary(transformationDictionary);
    pmml.addModels(regressionModel);
}
Also used : PredictorTerm(org.dmg.pmml.regression.PredictorTerm) PMMLModelTestUtils.getPredictorTerm(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getPredictorTerm) TransformationDictionary(org.dmg.pmml.TransformationDictionary) ArrayList(java.util.ArrayList) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) PMMLModelTestUtils.getNumericPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getNumericPredictor) CategoricalPredictor(org.dmg.pmml.regression.CategoricalPredictor) PMMLModelTestUtils.getCategoricalPredictor(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getCategoricalPredictor) Random(java.util.Random) PMML(org.dmg.pmml.PMML) HashSet(java.util.HashSet) BeforeClass(org.junit.BeforeClass)

Example 8 with NumericPredictor

Use of org.dmg.pmml.regression.NumericPredictor in the drools project by kiegroup.

Class KiePMMLRegressionTableFactoryTest, method getNumericPredictorExpressionWithExponent.

/**
 * Verifies that a numeric predictor with a non-trivial exponent is translated into the cast
 * expression described by the TEST_01_SOURCE template, filled with the same coefficient and
 * exponent.
 */
@Test
public void getNumericPredictorExpressionWithExponent() throws IOException {
    final String name = "predictorName";
    final int power = 2;
    final double coeff = 1.23;
    final NumericPredictor input = PMMLModelTestUtils.getNumericPredictor(name, power, coeff);

    final CastExpr actual = KiePMMLRegressionTableFactory.getNumericPredictorExpression(input);

    // The template holds format placeholders for the coefficient followed by the exponent.
    final String template = getFileContent(TEST_01_SOURCE);
    final String expectedSource = String.format(template, coeff, power);
    final Expression expected = JavaParserUtils.parseExpression(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expected, actual));
}
Also used : Expression(com.github.javaparser.ast.expr.Expression) CastExpr(com.github.javaparser.ast.expr.CastExpr) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) Test(org.junit.Test)

Example 9 with NumericPredictor

Use of org.dmg.pmml.regression.NumericPredictor in the shifu project by ShifuML.

Class PMMLAdapterCommonUtil, method getRegressionTable.

/**
 * Generate a Regression Table based on the weight list, intercept and partial
 * PMML model, and attach it to that model.
 * <p>
 * This method modifies {@code pmmlModel} as follows:
 * <ul>
 * <li>sets the mining function to REGRESSION and the normalization method to LOGIT;</li>
 * <li>sets the target field name to the first TARGET field of the mining schema, and uses
 * that same name as the table's target category;</li>
 * <li>adds one numeric predictor for each derived field whose expression is a
 * {@link NormContinuous} over an ACTIVE mining field, consuming weights in derived-field
 * order.</li>
 * </ul>
 *
 * @param weights
 *            weight list for the Regression Table; needs at least one entry per matching
 *            derived field, otherwise an {@link ArrayIndexOutOfBoundsException} is thrown
 * @param intercept
 *            the intercept written to the Regression Table
 * @param pmmlModel
 *            partial PMML model; expected to declare at least one TARGET mining field and
 *            to have local transformations
 * @return the same {@code pmmlModel} instance, with the new Regression Table added
 */
public static RegressionModel getRegressionTable(final double[] weights, final double intercept, RegressionModel pmmlModel) {
    RegressionTable table = new RegressionTable();
    // The intercept must be stored on the table itself; without this the model has no bias term.
    table.setIntercept(intercept);
    MiningSchema schema = pmmlModel.getMiningSchema();
    // TODO may not need target field in LRModel
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    List<String> outputFields = getSchemaFieldViaUsageType(schema, UsageType.TARGET);
    // TODO only one outputField, what if we have more than one outputField
    pmmlModel.setTargetFieldName(new FieldName(outputFields.get(0)));
    table.setTargetCategory(outputFields.get(0));
    List<String> activeFields = getSchemaFieldViaUsageType(schema, UsageType.ACTIVE);
    int index = 0;
    for (DerivedField dField : pmmlModel.getLocalTransformations().getDerivedFields()) {
        Expression expression = dField.getExpression();
        // Only NormContinuous derivations of ACTIVE fields become predictors.
        if (expression instanceof NormContinuous) {
            NormContinuous norm = (NormContinuous) expression;
            if (activeFields.contains(norm.getField().getValue())) {
                table.addNumericPredictors(new NumericPredictor(dField.getName(), weights[index++]));
            }
        }
    }
    pmmlModel.addRegressionTables(table);
    return pmmlModel;
}
Also used : NormContinuous(org.dmg.pmml.NormContinuous) MiningSchema(org.dmg.pmml.MiningSchema) Expression(org.dmg.pmml.Expression) FieldName(org.dmg.pmml.FieldName) DerivedField(org.dmg.pmml.DerivedField) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable)

Example 10 with NumericPredictor

Use of org.dmg.pmml.regression.NumericPredictor in the shifu project by ShifuML.

Class PMMLLRModelBuilder, method adaptMLModelToPMML.

/**
 * Adapt a Shifu logistic-regression model to PMML by adding a regression table to the given
 * partial PMML model.
 * <p>
 * The model's normalization method is set to LOGIT and its mining function to REGRESSION. The
 * table's intercept is the LR bias. Each ACTIVE mining field is followed through the chain of
 * local-transformation derived fields (NormContinuous, MapValues, Discretize) to the last
 * derived field name. That name becomes a numeric predictor whose coefficient is the next LR
 * weight, taken in mining-field order.
 *
 * @param lr        LR model supplying the bias and weights
 * @param pmmlModel partial PMML model whose mining schema and local transformations are read
 * @return the same {@code pmmlModel} instance, with the new regression table added
 * @throws IllegalStateException if the derived-field mapping contains a cycle
 */
public RegressionModel adaptMLModelToPMML(ml.shifu.shifu.core.LR lr, RegressionModel pmmlModel) {
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    RegressionTable table = new RegressionTable();
    table.setIntercept(lr.getBias());

    // Map each source field to the derived field that is computed directly from it.
    HashMap<FieldName, FieldName> miningTransformMap = new HashMap<FieldName, FieldName>();
    for (DerivedField dField : pmmlModel.getLocalTransformations().getDerivedFields()) {
        if (dField.getExpression() instanceof NormContinuous) {
            // Apply z-scale normalization on numerical variables
            miningTransformMap.put(((NormContinuous) dField.getExpression()).getField(), dField.getName());
        } else if (dField.getExpression() instanceof MapValues) {
            // Apply bin map on categorical variables
            miningTransformMap.put(((MapValues) dField.getExpression()).getFieldColumnPairs().get(0).getField(), dField.getName());
        } else if (dField.getExpression() instanceof Discretize) {
            miningTransformMap.put(((Discretize) dField.getExpression()).getField(), dField.getName());
        }
    }

    // Weights are consumed in mining-schema order, one per ACTIVE field.
    int index = 0;
    for (MiningField mField : pmmlModel.getMiningSchema().getMiningFields()) {
        if (mField.getUsageType() != UsageType.ACTIVE) {
            continue;
        }
        NumericPredictor np = new NumericPredictor();
        np.setName(resolveFinalFieldName(mField.getName(), miningTransformMap));
        np.setCoefficient(lr.getWeights()[index++]);
        table.addNumericPredictors(np);
    }
    pmmlModel.addRegressionTables(table);
    return pmmlModel;
}

/**
 * Follow the source-to-derived field chain starting at {@code fieldName} until reaching a name
 * that is not the source of any further derivation.
 *
 * @throws IllegalStateException if the chain is cyclic (an acyclic chain can take at most
 *         {@code transformMap.size()} hops); the unbounded walk would otherwise never end
 */
private static FieldName resolveFinalFieldName(FieldName fieldName, HashMap<FieldName, FieldName> transformMap) {
    FieldName current = fieldName;
    int hops = 0;
    while (transformMap.containsKey(current)) {
        if (hops++ >= transformMap.size()) {
            throw new IllegalStateException("Cyclic derived-field mapping detected starting from field " + fieldName.getValue());
        }
        current = transformMap.get(current);
    }
    return current;
}
Also used : NormContinuous(org.dmg.pmml.NormContinuous) MiningField(org.dmg.pmml.MiningField) HashMap(java.util.HashMap) NumericPredictor(org.dmg.pmml.regression.NumericPredictor) RegressionTable(org.dmg.pmml.regression.RegressionTable) LocalTransformations(org.dmg.pmml.LocalTransformations) MapValues(org.dmg.pmml.MapValues) Discretize(org.dmg.pmml.Discretize) DerivedField(org.dmg.pmml.DerivedField) FieldName(org.dmg.pmml.FieldName)

Aggregations

NumericPredictor (org.dmg.pmml.regression.NumericPredictor)10 Test (org.junit.Test)4 CategoricalPredictor (org.dmg.pmml.regression.CategoricalPredictor)3 PredictorTerm (org.dmg.pmml.regression.PredictorTerm)3 RegressionTable (org.dmg.pmml.regression.RegressionTable)3 CastExpr (com.github.javaparser.ast.expr.CastExpr)2 Expression (com.github.javaparser.ast.expr.Expression)2 DerivedField (org.dmg.pmml.DerivedField)2 FieldName (org.dmg.pmml.FieldName)2 NormContinuous (org.dmg.pmml.NormContinuous)2 PMMLModelTestUtils.getCategoricalPredictor (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getCategoricalPredictor)2 PMMLModelTestUtils.getNumericPredictor (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getNumericPredictor)2 PMMLModelTestUtils.getPredictorTerm (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getPredictorTerm)2 ArrayList (java.util.ArrayList)1 HashMap (java.util.HashMap)1 HashSet (java.util.HashSet)1 Random (java.util.Random)1 Discretize (org.dmg.pmml.Discretize)1 Expression (org.dmg.pmml.Expression)1 LocalTransformations (org.dmg.pmml.LocalTransformations)1