
Example 1 with Targets

Use of org.dmg.pmml.Targets in project shifu by ShifuML.

From class NNPmmlModelCreator, method createTargets.

public Targets createTargets() {
    Targets targets = new Targets();
    if (modelConfig.isClassification() && ModelTrainConf.MultipleClassification.NATIVE.equals(modelConfig.getTrain().getMultiClassifyMethod())) {
        List<Target> targetList = createMultiClassTargets();
        targets.addTargets(targetList.toArray(new Target[targetList.size()]));
    } else {
        Target target = new Target();
        target.setOpType(OpType.CONTINUOUS);
        target.setField(new FieldName(modelConfig.getTargetColumnName()));
        List<TargetValue> targetValueList = new ArrayList<TargetValue>();
        if (CollectionUtils.isNotEmpty(modelConfig.getPosTags())) {
            for (String posTagValue : modelConfig.getPosTags()) {
                TargetValue pos = new TargetValue();
                pos.setValue(posTagValue);
                pos.setDisplayValue("Positive");
                targetValueList.add(pos);
            }
        }
        if (CollectionUtils.isNotEmpty(modelConfig.getNegTags())) {
            for (String negTagValue : modelConfig.getNegTags()) {
                TargetValue neg = new TargetValue();
                neg.setValue(negTagValue);
                neg.setDisplayValue("Negative");
                targetValueList.add(neg);
            }
        }
        target.addTargetValues(targetValueList.toArray(new TargetValue[targetValueList.size()]));
        targets.addTargets(target);
    }
    return targets;
}
Also used: Target (org.dmg.pmml.Target), TargetValue (org.dmg.pmml.TargetValue), ArrayList (java.util.ArrayList), Targets (org.dmg.pmml.Targets), FieldName (org.dmg.pmml.FieldName)
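The else branch above boils down to a simple mapping: every positive tag becomes a target value displayed as "Positive", every negative tag one displayed as "Negative", and empty or missing tag lists are skipped (the job `CollectionUtils.isNotEmpty` does). The plain-Java sketch below reproduces only that mapping, without the JPMML classes. The class `TargetValueSketch`, its nested `Entry` type and the method `buildTargetValues` are hypothetical names for illustration, not part of shifu or org.dmg.pmml.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

// Hypothetical stand-in for the pos/neg tag mapping in createTargets();
// it does not use org.dmg.pmml and is not shifu's actual code.
public class TargetValueSketch {

    // Hypothetical equivalent of a TargetValue: the raw tag plus its display label.
    static final class Entry {
        final String value;
        final String displayValue;

        Entry(String value, String displayValue) {
            this.value = value;
            this.displayValue = displayValue;
        }

        @Override
        public String toString() {
            return value + "=" + displayValue;
        }
    }

    // Positive tags first, then negative tags, in the same order as the example.
    // Null lists are treated as empty, like the CollectionUtils.isNotEmpty guards.
    static List<Entry> buildTargetValues(List<String> posTags, List<String> negTags) {
        List<Entry> result = new ArrayList<>();
        for (String tag : posTags == null ? Collections.<String>emptyList() : posTags) {
            result.add(new Entry(tag, "Positive"));
        }
        for (String tag : negTags == null ? Collections.<String>emptyList() : negTags) {
            result.add(new Entry(tag, "Negative"));
        }
        return result;
    }

    public static void main(String[] args) {
        List<Entry> values = buildTargetValues(List.of("1"), List.of("0", "-1"));
        System.out.println(values); // prints [1=Positive, 0=Negative, -1=Negative]
    }
}
```

Example 2 below builds the same list without the emptiness guards, so it relies on `getPosTags()` and `getNegTags()` never returning null.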

Example 2 with Targets

Use of org.dmg.pmml.Targets in project shifu by ShifuML.

From class TreeEnsemblePmmlCreator, method createTargets.

protected Targets createTargets(ModelConfig modelConfig) {
    Targets targets = new Targets();
    Target target = new Target();
    target.setOpType(OpType.CATEGORICAL);
    target.setField(new FieldName(modelConfig.getTargetColumnName()));
    List<TargetValue> targetValueList = new ArrayList<TargetValue>();
    for (String posTagValue : modelConfig.getPosTags()) {
        TargetValue pos = new TargetValue();
        pos.setValue(posTagValue);
        pos.setDisplayValue("Positive");
        targetValueList.add(pos);
    }
    for (String negTagValue : modelConfig.getNegTags()) {
        TargetValue neg = new TargetValue();
        neg.setValue(negTagValue);
        neg.setDisplayValue("Negative");
        targetValueList.add(neg);
    }
    target.addTargetValues(targetValueList.toArray(new TargetValue[targetValueList.size()]));
    targets.addTargets(target);
    return targets;
}
Also used: Target (org.dmg.pmml.Target), TargetValue (org.dmg.pmml.TargetValue), ArrayList (java.util.ArrayList), Targets (org.dmg.pmml.Targets), FieldName (org.dmg.pmml.FieldName)

Example 3 with Targets

Use of org.dmg.pmml.Targets in project drools by kiegroup.

From class ModelUtilsTest, method getTargetFieldsWithTargetFieldsWithTargetsWithoutOptType.

@Test
public void getTargetFieldsWithTargetFieldsWithTargetsWithoutOptType() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();
    IntStream.range(0, 3).forEach(i -> {
        final String fieldName = "fieldName-" + i;
        final DataField dataField = getDataField(fieldName, OpType.CATEGORICAL, DataType.STRING);
        dataDictionary.addDataFields(dataField);
        final MiningField miningField = getMiningField(fieldName, MiningField.UsageType.PREDICTED);
        miningField.setOpType(OpType.CONTINUOUS);
        miningSchema.addMiningFields(miningField);
        final Target targetField = getTarget(fieldName, null);
        targets.addTargets(targetField);
    });
    model.setMiningSchema(miningSchema);
    model.setTargets(targets);
    List<KiePMMLNameOpType> retrieved = ModelUtils.getTargetFields(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());
    retrieved.forEach(kiePMMLNameOpType -> {
        Optional<MiningField> optionalMiningField = miningSchema.getMiningFields().stream().filter(fld -> kiePMMLNameOpType.getName().equals(fld.getName().getValue())).findFirst();
        assertTrue(optionalMiningField.isPresent());
        MiningField miningField = optionalMiningField.get();
        OP_TYPE expected = OP_TYPE.byName(miningField.getOpType().value());
        assertEquals(expected, kiePMMLNameOpType.getOpType());
    });
}
Also used: RESULT_FEATURE (org.kie.pmml.api.enums.RESULT_FEATURE), ModelUtils.getPrefixedName (org.kie.pmml.compiler.api.utils.ModelUtils.getPrefixedName), Arrays (java.util.Arrays), Date (java.util.Date), Model (org.dmg.pmml.Model), MiningSchema (org.dmg.pmml.MiningSchema), OP_TYPE (org.kie.pmml.api.enums.OP_TYPE), Row (org.dmg.pmml.Row), FieldName (org.dmg.pmml.FieldName), FIELD_USAGE_TYPE (org.kie.pmml.api.enums.FIELD_USAGE_TYPE), OpType (org.dmg.pmml.OpType), Map (java.util.Map), InputCell (org.jpmml.model.inlinetable.InputCell), PMMLModelTestUtils.getDataTypes (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataTypes), PMMLModelTestUtils.getMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField), RegressionModel (org.dmg.pmml.regression.RegressionModel), Targets (org.dmg.pmml.Targets), DataType (org.dmg.pmml.DataType), PMMLModelTestUtils.getRandomRow (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomRow), Collectors (java.util.stream.Collectors), PMMLModelTestUtils.getParameterFields (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getParameterFields), DataField (org.dmg.pmml.DataField), List (java.util.List), Assert.assertFalse (org.junit.Assert.assertFalse), Optional (java.util.Optional), PMMLModelTestUtils.getArray (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getArray), PMMLModelTestUtils.getRandomDataType (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataType), ParameterField (org.dmg.pmml.ParameterField), IntStream (java.util.stream.IntStream), OutputField (org.dmg.pmml.OutputField), PMMLModelTestUtils.getDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField), HashMap (java.util.HashMap), ArrayList (java.util.ArrayList), LinkedHashMap (java.util.LinkedHashMap), Field (org.dmg.pmml.Field), DerivedField (org.dmg.pmml.DerivedField), KiePMMLInternalException (org.kie.pmml.api.exceptions.KiePMMLInternalException), KiePMMLNameOpType (org.kie.pmml.commons.model.tuples.KiePMMLNameOpType), OutputCell (org.jpmml.model.inlinetable.OutputCell), MiningField (org.dmg.pmml.MiningField), Iterator (java.util.Iterator), Assert.assertNotNull (org.junit.Assert.assertNotNull), Assert.assertTrue (org.junit.Assert.assertTrue), DataDictionary (org.dmg.pmml.DataDictionary), Test (org.junit.Test), PMMLModelTestUtils.getRandomDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField), CommonTestingUtils.getFieldsFromDataDictionary (org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary), Target (org.dmg.pmml.Target), DATA_TYPE (org.kie.pmml.api.enums.DATA_TYPE), Array (org.dmg.pmml.Array), PMMLModelTestUtils.getRandomOutputField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomOutputField), PMMLModelTestUtils.getRandomTarget (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomTarget), Assert.assertNull (org.junit.Assert.assertNull), PMMLModelTestUtils.getRandomMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField), PMMLModelTestUtils.getTarget (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getTarget), PMMLModelTestUtils.getRandomRowWithCells (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomRowWithCells), Assert.assertEquals (org.junit.Assert.assertEquals)
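In this test each `Target` is created with a null op type, each `MiningField` is set to `CONTINUOUS` and each `DataField` to `CATEGORICAL`, and the assertions expect the mining field's op type back. So when a target carries no op type, the mining field wins over the data dictionary. The plain-Java sketch below encodes that fallback order. It assumes, beyond what this single test shows, that a non-null op type on the target would take precedence over both; `OpTypeFallbackSketch` and `resolve` are hypothetical names, and the real `ModelUtils.getTargetFields` may differ in detail.

```java
// Hypothetical sketch of an op-type fallback order; not the kie-pmml implementation.
public class OpTypeFallbackSketch {

    // The three PMML op types.
    enum OpType { CATEGORICAL, ORDINAL, CONTINUOUS }

    // Returns the first non-null op type in the assumed order
    // Target -> MiningField -> DataField; returns null if none is set.
    static OpType resolve(OpType fromTarget, OpType fromMiningField, OpType fromDataField) {
        if (fromTarget != null) {
            return fromTarget;
        }
        if (fromMiningField != null) {
            return fromMiningField;
        }
        return fromDataField;
    }

    public static void main(String[] args) {
        // Mirrors the test setup: no op type on the Target,
        // CONTINUOUS on the MiningField, CATEGORICAL on the DataField.
        System.out.println(resolve(null, OpType.CONTINUOUS, OpType.CATEGORICAL)); // prints CONTINUOUS
    }
}
```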

Example 4 with Targets

Use of org.dmg.pmml.Targets in project drools by kiegroup.

From class KiePMMLUtilTest, method correctTargetFields.

@Test
public void correctTargetFields() {
    final MiningField miningField = new MiningField(FieldName.create("FIELD_NAME"));
    final Targets targets = new Targets();
    final Target namedTarget = new Target();
    String targetName = "TARGET_NAME";
    namedTarget.setField(FieldName.create(targetName));
    final Target unnamedTarget = new Target();
    targets.addTargets(namedTarget, unnamedTarget);
    KiePMMLUtil.correctTargetFields(miningField, targets);
    assertEquals(targetName, namedTarget.getField().getValue());
    assertEquals(miningField.getName(), unnamedTarget.getField());
}
Also used: MiningField (org.dmg.pmml.MiningField), Target (org.dmg.pmml.Target), Targets (org.dmg.pmml.Targets), Test (org.junit.Test)
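The two assertions in this test pin down the expected behavior: a `Target` that already names a field keeps that name, while a `Target` with no field is given the mining field's name. The plain-Java sketch below implements just that rule. `SimpleTarget` and `correctTargetFields` here are hypothetical stand-ins for `org.dmg.pmml.Target` and `KiePMMLUtil.correctTargetFields`; the real method is not reproduced and may apply further conditions.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch of the rule asserted by the correctTargetFields test;
// it does not use org.dmg.pmml or kie-pmml classes.
public class TargetFieldCorrectionSketch {

    // Hypothetical stand-in for a Target: only its (possibly null) field name.
    static final class SimpleTarget {
        String field;

        SimpleTarget(String field) {
            this.field = field;
        }
    }

    // Gives every unnamed target the mining field's name; named targets are left unchanged.
    static void correctTargetFields(String miningFieldName, List<SimpleTarget> targets) {
        for (SimpleTarget target : targets) {
            if (target.field == null) {
                target.field = miningFieldName;
            }
        }
    }

    public static void main(String[] args) {
        SimpleTarget named = new SimpleTarget("TARGET_NAME");
        SimpleTarget unnamed = new SimpleTarget(null);
        correctTargetFields("FIELD_NAME", Arrays.asList(named, unnamed));
        System.out.println(named.field + " " + unnamed.field); // prints TARGET_NAME FIELD_NAME
    }
}
```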

Example 5 with Targets

Use of org.dmg.pmml.Targets in project drools by kiegroup.

From class ModelUtilsTest, method getOpTypeByTargetsNotFound.

@Test(expected = KiePMMLInternalException.class)
public void getOpTypeByTargetsNotFound() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();
    IntStream.range(0, 3).forEach(i -> {
        String fieldName = "field" + i;
        final DataField dataField = getRandomDataField();
        dataField.setName(FieldName.create(fieldName));
        dataDictionary.addDataFields(dataField);
        final MiningField miningField = getRandomMiningField();
        miningField.setName(dataField.getName());
        miningSchema.addMiningFields(miningField);
        final Target targetField = getRandomTarget();
        targetField.setField(dataField.getName());
        targets.addTargets(targetField);
    });
    model.setMiningSchema(miningSchema);
    model.setTargets(targets);
    ModelUtils.getOpType(getFieldsFromDataDictionary(dataDictionary), model, "NOT_EXISTING");
}
Also used: PMMLModelTestUtils.getMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField), MiningField (org.dmg.pmml.MiningField), PMMLModelTestUtils.getRandomMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField), Target (org.dmg.pmml.Target), PMMLModelTestUtils.getRandomTarget (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomTarget), PMMLModelTestUtils.getTarget (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getTarget), MiningSchema (org.dmg.pmml.MiningSchema), DataField (org.dmg.pmml.DataField), PMMLModelTestUtils.getDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField), PMMLModelTestUtils.getRandomDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField), Model (org.dmg.pmml.Model), RegressionModel (org.dmg.pmml.regression.RegressionModel), Targets (org.dmg.pmml.Targets), DataDictionary (org.dmg.pmml.DataDictionary), CommonTestingUtils.getFieldsFromDataDictionary (org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary), Test (org.junit.Test)
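This test expects `ModelUtils.getOpType` to throw `KiePMMLInternalException` when the requested name ("NOT_EXISTING") matches none of the defined fields, rather than silently returning a default. The plain-Java sketch below shows such a fail-fast lookup: each source (for example targets, mining fields, data fields) is a map from field name to op type, checked in the order given. `OpTypeLookupSketch` and its `getOpType` are hypothetical names, and `IllegalStateException` stands in for `KiePMMLInternalException`.

```java
import java.util.List;
import java.util.Map;

// Hypothetical sketch of a fail-fast op-type lookup by field name;
// not the kie-pmml implementation.
public class OpTypeLookupSketch {

    enum OpType { CATEGORICAL, ORDINAL, CONTINUOUS }

    // Sources are checked in order; the first one that knows the name wins.
    // If no source knows the name, the lookup fails instead of guessing.
    static OpType getOpType(List<Map<String, OpType>> sources, String fieldName) {
        for (Map<String, OpType> source : sources) {
            OpType opType = source.get(fieldName);
            if (opType != null) {
                return opType;
            }
        }
        // IllegalStateException stands in for KiePMMLInternalException.
        throw new IllegalStateException("Failed to find OpType for field " + fieldName);
    }

    public static void main(String[] args) {
        List<Map<String, OpType>> sources = List.of(
                Map.of("field0", OpType.CATEGORICAL),
                Map.of("field0", OpType.CONTINUOUS, "field1", OpType.ORDINAL));
        System.out.println(getOpType(sources, "field1")); // prints ORDINAL
        try {
            getOpType(sources, "NOT_EXISTING");
        } catch (IllegalStateException e) {
            System.out.println(e.getMessage()); // prints Failed to find OpType for field NOT_EXISTING
        }
    }
}
```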

Aggregations

Target (org.dmg.pmml.Target): 10
Targets (org.dmg.pmml.Targets): 10
Test (org.junit.Test): 7
MiningField (org.dmg.pmml.MiningField): 6
ArrayList (java.util.ArrayList): 5
DataDictionary (org.dmg.pmml.DataDictionary): 5
DataField (org.dmg.pmml.DataField): 5
FieldName (org.dmg.pmml.FieldName): 5
MiningSchema (org.dmg.pmml.MiningSchema): 5
Model (org.dmg.pmml.Model): 5
RegressionModel (org.dmg.pmml.regression.RegressionModel): 5
PMMLModelTestUtils.getRandomTarget (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomTarget): 5
PMMLModelTestUtils.getTarget (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getTarget): 5
OP_TYPE (org.kie.pmml.api.enums.OP_TYPE): 4
CommonTestingUtils.getFieldsFromDataDictionary (org.kie.pmml.compiler.api.CommonTestingUtils.getFieldsFromDataDictionary): 4
PMMLModelTestUtils.getDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getDataField): 4
PMMLModelTestUtils.getMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getMiningField): 4
PMMLModelTestUtils.getRandomDataField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomDataField): 4
PMMLModelTestUtils.getRandomMiningField (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getRandomMiningField): 4
TargetValue (org.dmg.pmml.TargetValue): 3