Usage of org.dmg.pmml.LocalTransformations in project drools (by kiegroup).
Class KiePMMLLocalTransformationsFactoryTest, method getKiePMMLTransformationDictionaryVariableDeclaration.
/**
 * Checks that {@code getKiePMMLLocalTransformationsVariableDeclaration} turns a
 * {@link LocalTransformations} holding the fixture derived fields into exactly the block of
 * statements read from {@code TEST_01_SOURCE}, and that the generated block compiles against
 * the listed import classes.
 * <p>
 * NOTE(review): the method name mentions "TransformationDictionary" but the code under test is
 * the LocalTransformations factory; the name is kept unchanged for test discovery/history.
 *
 * @throws IOException if the expected-source fixture cannot be read
 */
@Test
public void getKiePMMLTransformationDictionaryVariableDeclaration() throws IOException {
    // Expected output, parsed from the fixture file.
    final String expectedSource = getFileContent(TEST_01_SOURCE);
    final Statement expectedBlock = JavaParserUtils.parseBlock(expectedSource);

    // Input: a LocalTransformations populated with the shared fixture derived fields.
    final LocalTransformations input = new LocalTransformations();
    input.addDerivedFields(getDerivedFields());

    final BlockStmt generated = KiePMMLLocalTransformationsFactory.getKiePMMLLocalTransformationsVariableDeclaration(input);
    assertTrue(JavaParserUtils.equalsNode(expectedBlock, generated));

    // Classes the generated statements reference; they must resolve for compilation to succeed.
    final List<Class<?>> requiredImports = Arrays.asList(KiePMMLConstant.class,
                                                         KiePMMLApply.class,
                                                         KiePMMLDerivedField.class,
                                                         KiePMMLLocalTransformations.class,
                                                         Arrays.class,
                                                         Collections.class);
    commonValidateCompilationWithImports(generated, requiredImports);
}
Usage of org.dmg.pmml.LocalTransformations in project drools (by kiegroup).
Class KiePMMLLocalTransformationsInstanceFactoryTest, method getKiePMMLLocalTransformations.
/**
 * Converts a randomly generated {@link LocalTransformations} into a
 * {@link KiePMMLLocalTransformations} and checks that:
 * <ul>
 *   <li>the conversion result is not null,</li>
 *   <li>source and converted derived-field lists have the same size, and</li>
 *   <li>every source {@link DerivedField} has a converted counterpart with the same name
 *       that passes {@code commonVerifyKiePMMLDerivedField}.</li>
 * </ul>
 */
@Test
public void getKiePMMLLocalTransformations() {
    final LocalTransformations toConvert = getRandomLocalTransformations();
    final KiePMMLLocalTransformations converted =
            KiePMMLLocalTransformationsInstanceFactory.getKiePMMLLocalTransformations(toConvert, Collections.emptyList());
    assertNotNull(converted);

    final List<DerivedField> sourceFields = toConvert.getDerivedFields();
    final List<KiePMMLDerivedField> convertedFields = converted.getDerivedFields();
    assertEquals(sourceFields.size(), convertedFields.size());

    for (DerivedField sourceField : sourceFields) {
        // Match by name: the converted field's name is the source FieldName's string value.
        final Optional<KiePMMLDerivedField> counterpart = convertedFields.stream()
                .filter(candidate -> candidate.getName().equals(sourceField.getName().getValue()))
                .findFirst();
        assertTrue(counterpart.isPresent());
        commonVerifyKiePMMLDerivedField(counterpart.get(), sourceField);
    }
}
Usage of org.dmg.pmml.LocalTransformations in project shifu (by ShifuML).
Class PMMLLRModelBuilder, method adaptMLModelToPMML.
/**
 * Populates {@code pmmlModel} with a single regression table built from the given LR model.
 * <p>
 * Sets the normalization method to {@code LOGIT} and the mining function to
 * {@code REGRESSION}, uses {@code lr.getBias()} as the table intercept, and adds one
 * {@link NumericPredictor} per ACTIVE mining field, in mining-schema order, taking the
 * coefficients from {@code lr.getWeights()} in that same order. Each predictor is named after
 * the last derived field reached by following the model's local transformations
 * (NormContinuous, MapValues, Discretize) starting from the raw mining field; a field with no
 * transformation keeps its mining-field name.
 * <p>
 * NOTE(review): assumes {@code lr.getWeights()} has at least one entry per ACTIVE mining field;
 * otherwise the array access throws.
 *
 * @param lr        LR model supplying the bias and weights
 * @param pmmlModel model to populate; modified in place
 * @return the same {@code pmmlModel} instance
 */
public RegressionModel adaptMLModelToPMML(ml.shifu.shifu.core.LR lr, RegressionModel pmmlModel) {
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    RegressionTable table = new RegressionTable();
    table.setIntercept(lr.getBias());

    // source field name -> name of the derived field that transforms it
    HashMap<FieldName, FieldName> miningTransformMap = new HashMap<FieldName, FieldName>();
    LocalTransformations lt = pmmlModel.getLocalTransformations();
    // LocalTransformations is optional in a PMML model; without it, predictors use raw names.
    if (lt != null) {
        List<DerivedField> df = lt.getDerivedFields();
        for (DerivedField dField : df) {
            if (dField.getExpression() instanceof NormContinuous) {
                // z-scale normalization on numerical variables
                miningTransformMap.put(((NormContinuous) dField.getExpression()).getField(), dField.getName());
            } else if (dField.getExpression() instanceof MapValues) {
                // bin map on categorical variables; the first column pair names the source field
                MapValues mapValues = (MapValues) dField.getExpression();
                if (!mapValues.getFieldColumnPairs().isEmpty()) {
                    miningTransformMap.put(mapValues.getFieldColumnPairs().get(0).getField(), dField.getName());
                }
            } else if (dField.getExpression() instanceof Discretize) {
                miningTransformMap.put(((Discretize) dField.getExpression()).getField(), dField.getName());
            }
        }
    }

    List<MiningField> miningList = pmmlModel.getMiningSchema().getMiningFields();
    int index = 0;
    for (MiningField mField : miningList) {
        if (mField.getUsageType() != UsageType.ACTIVE) {
            continue;
        }
        // Follow raw -> derived -> derived-of-derived. An acyclic chain has at most
        // miningTransformMap.size() hops, so the bound only stops a cyclic mapping (e.g. a
        // derived field reusing its source field's name), which would otherwise loop forever.
        FieldName fName = mField.getName();
        int hops = 0;
        while (miningTransformMap.containsKey(fName) && hops++ < miningTransformMap.size()) {
            fName = miningTransformMap.get(fName);
        }
        NumericPredictor np = new NumericPredictor();
        np.setName(fName);
        np.setCoefficient(lr.getWeights()[index++]);
        table.addNumericPredictors(np);
    }
    pmmlModel.addRegressionTables(table);
    return pmmlModel;
}
Usage of org.dmg.pmml.LocalTransformations in project shifu (by ShifuML).
Class ZscoreLocalTransformCreator, method build.
/**
 * Builds the {@link LocalTransformations} for {@code basicML}. Each column that qualifies
 * contributes its group of derived fields.
 * <p>
 * A column qualifies when it is final-selected. If {@code basicML} is a
 * {@link BasicFloatNetwork} whose feature set is non-empty, the column must also appear in that
 * set, matched by column number. Otherwise every final-selected column is used.
 * <p>
 * Categorical columns go through {@code createCategoricalDerivedField}. All other columns go
 * through {@code createNumericalDerivedField}. Both receive the configured std-dev cutoff and
 * normalize type.
 *
 * @param basicML the trained model whose input columns are being transformed
 * @return a new LocalTransformations holding the derived fields of every qualifying column
 */
@Override
public LocalTransformations build(BasicML basicML) {
    LocalTransformations result = new LocalTransformations();

    // Only a BasicFloatNetwork carries a feature subset. Leaving this null means "no restriction",
    // because CollectionUtils.isEmpty(null) is true.
    Set<Integer> selectedFeatures = null;
    if (basicML instanceof BasicFloatNetwork) {
        selectedFeatures = ((BasicFloatNetwork) basicML).getFeatureSet();
    }

    for (ColumnConfig column : columnConfigList) {
        if (!column.isFinalSelect()) {
            continue;
        }
        if (!CollectionUtils.isEmpty(selectedFeatures) && !selectedFeatures.contains(column.getColumnNum())) {
            continue;
        }
        double cutoff = modelConfig.getNormalizeStdDevCutOff();
        List<DerivedField> derivedFields;
        if (column.isCategorical()) {
            derivedFields = createCategoricalDerivedField(column, cutoff, modelConfig.getNormalizeType());
        } else {
            derivedFields = createNumericalDerivedField(column, cutoff, modelConfig.getNormalizeType());
        }
        result.addDerivedFields(derivedFields.toArray(new DerivedField[derivedFields.size()]));
    }
    return result;
}
Aggregations