Usage of org.dmg.pmml.Targets in project shifu (by ShifuML).
Class NNPmmlModelCreator, method createTargets:
/**
 * Builds the PMML {@code Targets} element for a neural-network model.
 *
 * <p>When {@code modelConfig} describes a classification model trained with the
 * {@code NATIVE} multi-class method, the targets come from {@link #createMultiClassTargets()}.
 * Otherwise a single {@code CONTINUOUS} target is emitted for the configured target column.
 * It carries one {@code TargetValue} per positive tag (display value "Positive") followed
 * by one per negative tag (display value "Negative"). Empty or missing tag lists are skipped.
 *
 * @return a new {@code Targets} instance populated as described above
 */
public Targets createTargets() {
    Targets result = new Targets();
    boolean nativeMultiClass = modelConfig.isClassification()
            && ModelTrainConf.MultipleClassification.NATIVE.equals(modelConfig.getTrain().getMultiClassifyMethod());
    if (nativeMultiClass) {
        List<Target> multiClassTargets = createMultiClassTargets();
        result.addTargets(multiClassTargets.toArray(new Target[multiClassTargets.size()]));
        return result;
    }

    Target singleTarget = new Target();
    singleTarget.setOpType(OpType.CONTINUOUS);
    singleTarget.setField(new FieldName(modelConfig.getTargetColumnName()));

    List<TargetValue> values = new ArrayList<TargetValue>();
    // Positive tags are listed before negative tags, mirroring the configuration order.
    if (CollectionUtils.isNotEmpty(modelConfig.getPosTags())) {
        for (String tag : modelConfig.getPosTags()) {
            TargetValue positive = new TargetValue();
            positive.setValue(tag);
            positive.setDisplayValue("Positive");
            values.add(positive);
        }
    }
    if (CollectionUtils.isNotEmpty(modelConfig.getNegTags())) {
        for (String tag : modelConfig.getNegTags()) {
            TargetValue negative = new TargetValue();
            negative.setValue(tag);
            negative.setDisplayValue("Negative");
            values.add(negative);
        }
    }
    singleTarget.addTargetValues(values.toArray(new TargetValue[values.size()]));
    result.addTargets(singleTarget);
    return result;
}
Usage of org.dmg.pmml.Targets in project shifu (by ShifuML).
Class TreeEnsemblePmmlCreator, method createTargets:
/**
 * Builds the PMML {@code Targets} element for a tree-ensemble model.
 *
 * <p>A single {@code CATEGORICAL} target is emitted for the configured target column. It
 * carries one {@code TargetValue} per positive tag (display value "Positive") followed by
 * one per negative tag (display value "Negative").
 *
 * @param modelConfig the model configuration supplying the target column name and tags
 * @return a new {@code Targets} instance containing exactly one target
 */
protected Targets createTargets(ModelConfig modelConfig) {
    Targets targets = new Targets();
    Target target = new Target();
    target.setOpType(OpType.CATEGORICAL);
    target.setField(new FieldName(modelConfig.getTargetColumnName()));
    List<TargetValue> targetValueList = new ArrayList<TargetValue>();
    // Tag lists may be absent from the configuration (the NN creator guards them too).
    // Skip a missing list instead of throwing NullPointerException while iterating it.
    if (modelConfig.getPosTags() != null) {
        for (String posTagValue : modelConfig.getPosTags()) {
            TargetValue pos = new TargetValue();
            pos.setValue(posTagValue);
            pos.setDisplayValue("Positive");
            targetValueList.add(pos);
        }
    }
    if (modelConfig.getNegTags() != null) {
        for (String negTagValue : modelConfig.getNegTags()) {
            TargetValue neg = new TargetValue();
            neg.setValue(negTagValue);
            neg.setDisplayValue("Negative");
            targetValueList.add(neg);
        }
    }
    target.addTargetValues(targetValueList.toArray(new TargetValue[targetValueList.size()]));
    targets.addTargets(target);
    return targets;
}
Usage of org.dmg.pmml.Targets in project drools (by kiegroup).
Class ModelUtilsTest, method getTargetFieldsWithTargetFieldsWithTargetsWithoutOptType:
/**
 * Verifies that {@code ModelUtils.getTargetFields} returns one entry per predicted mining
 * field when the model's targets carry no op type. Each entry is expected to report the
 * op type set on the matching mining field.
 */
@Test
public void getTargetFieldsWithTargetFieldsWithTargetsWithoutOptType() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();
    for (int idx = 0; idx < 3; idx++) {
        final String name = "fieldName-" + idx;
        dataDictionary.addDataFields(getDataField(name, OpType.CATEGORICAL, DataType.STRING));
        final MiningField predicted = getMiningField(name, MiningField.UsageType.PREDICTED);
        // Deliberately differs from the data field's CATEGORICAL op type; the assertions
        // below expect the mining field's value to be the one reported.
        predicted.setOpType(OpType.CONTINUOUS);
        miningSchema.addMiningFields(predicted);
        // Target built with a null op type.
        targets.addTargets(getTarget(name, null));
    }
    model.setMiningSchema(miningSchema);
    model.setTargets(targets);

    final List<KiePMMLNameOpType> retrieved =
            ModelUtils.getTargetFields(getFieldsFromDataDictionary(dataDictionary), model);
    assertNotNull(retrieved);
    assertEquals(miningSchema.getMiningFields().size(), retrieved.size());
    for (KiePMMLNameOpType nameOpType : retrieved) {
        final Optional<MiningField> match = miningSchema.getMiningFields()
                .stream()
                .filter(candidate -> nameOpType.getName().equals(candidate.getName().getValue()))
                .findFirst();
        assertTrue(match.isPresent());
        final OP_TYPE expectedOpType = OP_TYPE.byName(match.get().getOpType().value());
        assertEquals(expectedOpType, nameOpType.getOpType());
    }
}
Usage of org.dmg.pmml.Targets in project drools (by kiegroup).
Class KiePMMLUtilTest, method correctTargetFields:
/**
 * Verifies {@code KiePMMLUtil.correctTargetFields}. A target that already names a field
 * keeps it, and a target with no field ends up pointing at the given mining field's name.
 */
@Test
public void correctTargetFields() {
    String targetName = "TARGET_NAME";
    final MiningField miningField = new MiningField(FieldName.create("FIELD_NAME"));

    final Target namedTarget = new Target();
    namedTarget.setField(FieldName.create(targetName));
    final Target unnamedTarget = new Target();

    final Targets targets = new Targets();
    targets.addTargets(namedTarget, unnamedTarget);

    KiePMMLUtil.correctTargetFields(miningField, targets);

    // The explicitly named target must be left untouched.
    assertEquals(targetName, namedTarget.getField().getValue());
    // The unnamed target is expected to adopt the mining field's name.
    assertEquals(miningField.getName(), unnamedTarget.getField());
}
Usage of org.dmg.pmml.Targets in project drools (by kiegroup).
Class ModelUtilsTest, method getOpTypeByTargetsNotFound:
/**
 * Verifies that {@code ModelUtils.getOpType} throws {@code KiePMMLInternalException}
 * when asked for a field name ("NOT_EXISTING") that appears in none of the data fields,
 * mining fields or targets built here.
 */
@Test(expected = KiePMMLInternalException.class)
public void getOpTypeByTargetsNotFound() {
    final Model model = new RegressionModel();
    final DataDictionary dataDictionary = new DataDictionary();
    final MiningSchema miningSchema = new MiningSchema();
    final Targets targets = new Targets();
    for (int idx = 0; idx < 3; idx++) {
        // Data field, mining field and target all share the name "field<idx>".
        final DataField dataField = getRandomDataField();
        dataField.setName(FieldName.create("field" + idx));
        dataDictionary.addDataFields(dataField);

        final MiningField miningField = getRandomMiningField();
        miningField.setName(dataField.getName());
        miningSchema.addMiningFields(miningField);

        final Target target = getRandomTarget();
        target.setField(dataField.getName());
        targets.addTargets(target);
    }
    model.setMiningSchema(miningSchema);
    model.setTargets(targets);
    ModelUtils.getOpType(getFieldsFromDataDictionary(dataDictionary), model, "NOT_EXISTING");
}
Aggregations