use of org.dmg.pmml.LinearNorm in project shifu by ShifuML.
the class WoeZscoreLocalTransformCreator method createNumericalDerivedField.
/**
 * Builds the derived fields for a numerical variable: the WOE field produced by the parent
 * creator, followed by a linearly normalized field computed on top of that WOE field.
 *
 * @param config   ColumnConfig of the numerical variable
 * @param cutoff   cutoff of normalization; the normalized output is anchored at -cutoff and +cutoff
 * @param normType normalization type passed to genPmmlColumnName for the second field's name
 *                 (the parent creator is always invoked with NormType.WOE)
 * @return list holding the WOE derived field first and the normalized derived field second
 */
@Override
protected List<DerivedField> createNumericalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) {
    // The parent creator supplies the WOE field; only its first element is used here.
    DerivedField woeField = super.createNumericalDerivedField(config, cutoff, ModelNormalizeConf.NormType.WOE).get(0);

    List<DerivedField> result = new ArrayList<DerivedField>();
    result.add(woeField);

    // Index 0 is used as the mean and index 1 as the standard deviation of the WOE values.
    double[] woeStats = Normalizer.calculateWoeMeanAndStdDev(config, isWeightedNorm);
    double woeMean = woeStats[0];
    double woeStdDev = woeStats[1];

    // Two anchor points: (mean - stdDev * cutoff -> -cutoff) and (mean + stdDev * cutoff -> cutoff).
    LinearNorm lowerBound = new LinearNorm().setOrig(woeMean - woeStdDev * cutoff).setNorm(-cutoff);
    LinearNorm upperBound = new LinearNorm().setOrig(woeMean + woeStdDev * cutoff).setNorm(cutoff);

    // The normalization reads from the WOE field created above.
    String woeFieldName = woeField.getName().getValue();
    NormContinuous zscale = new NormContinuous();
    zscale.setField(FieldName.create(woeFieldName));
    zscale.addLinearNorms(lowerBound, upperBound);
    // Missing input maps to 0.0; inputs outside the anchors are treated as extreme values (capping).
    zscale.setMapMissingTo(0.0);
    zscale.setOutliers(OutlierTreatmentMethod.AS_EXTREME_VALUES);

    // The output name is generated from the simple column name and the requested normType.
    String simpleName = NormalUtils.getSimpleColumnName(config.getColumnName());
    FieldName outputName = FieldName.create(genPmmlColumnName(simpleName, normType));
    DerivedField zscaleField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE)
            .setName(outputName)
            .setExpression(zscale);
    result.add(zscaleField);

    return result;
}
use of org.dmg.pmml.LinearNorm in project shifu by ShifuML.
the class ZscoreLocalTransformCreator method createNumericalDerivedField.
/**
 * Builds the derived field that applies a linear normalization with capping to a numerical
 * variable, using the mean and standard deviation stored in its ColumnConfig.
 *
 * @param config
 *            - ColumnConfig for numerical variable (supplies mean, standard deviation and name)
 * @param cutoff
 *            - cutoff of normalization; the normalized output is anchored at -cutoff and +cutoff
 * @param normType
 *            - the normalization method passed to genPmmlColumnName for the derived field's name
 * @return single-element list containing the DerivedField for the variable
 */
protected List<DerivedField> createNumericalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) {
    // Two anchor points: (mean - stdDev * cutoff -> -cutoff) and (mean + stdDev * cutoff -> cutoff).
    LinearNorm lowerBound = new LinearNorm().setOrig(config.getMean() - config.getStdDev() * cutoff).setNorm(-cutoff);
    LinearNorm upperBound = new LinearNorm().setOrig(config.getMean() + config.getStdDev() * cutoff).setNorm(cutoff);

    // The source field name is resolved against the column list, segment expansions and headers.
    FieldName sourceField = FieldName.create(
            NormalUtils.getSimpleColumnName(config, columnConfigList, segmentExpansions, datasetHeaders));

    NormContinuous zscale = new NormContinuous();
    zscale.setField(sourceField);
    zscale.addLinearNorms(lowerBound, upperBound);
    // Missing input maps to 0.0; inputs outside the anchors are treated as extreme values (capping).
    zscale.setMapMissingTo(0.0);
    zscale.setOutliers(OutlierTreatmentMethod.AS_EXTREME_VALUES);

    // The output name is generated from the simple column name and the requested normType.
    String simpleName = NormalUtils.getSimpleColumnName(config.getColumnName());
    DerivedField zscaleField = new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE)
            .setName(FieldName.create(genPmmlColumnName(simpleName, normType)))
            .setExpression(zscale);

    List<DerivedField> result = new ArrayList<DerivedField>();
    result.add(zscaleField);
    return result;
}
use of org.dmg.pmml.LinearNorm in project drools by kiegroup.
the class KiePMMLNormContinuousFactory method getNormContinuousVariableDeclaration.
/**
 * Generates the block of source code that declares a KiePMMLNormContinuous variable
 * equivalent to the given PMML NormContinuous.
 *
 * <p>The body of the template method is cloned, the declared variable is renamed to
 * {@code variableName}, and the constructor arguments at positions 0 (field name),
 * 3 (outlier treatment) and 4 (map-missing-to value) are replaced. The method call at
 * position 2 receives one KiePMMLLinearNorm creation expression per LinearNorm.
 *
 * @param variableName   name given to the generated variable
 * @param normContinuous PMML element to translate
 * @return the populated block statement
 * @throws KiePMMLException if the template lacks a body, the expected variable, or its initializer
 */
static BlockStmt getNormContinuousVariableDeclaration(final String variableName, final NormContinuous normContinuous) {
    final MethodDeclaration methodDeclaration = NORMCONTINUOUS_TEMPLATE.getMethodsByName(GETKIEPMMLNORMCONTINUOUS).get(0).clone();
    final BlockStmt toReturn = methodDeclaration.getBody().orElseThrow(() -> new KiePMMLException(String.format(MISSING_BODY_TEMPLATE, methodDeclaration)));
    final VariableDeclarator variableDeclarator = getVariableDeclarator(toReturn, NORM_CONTINUOUS).orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_IN_BODY, NORM_CONTINUOUS, toReturn)));
    variableDeclarator.setName(variableName);
    final ObjectCreationExpr objectCreationExpr = variableDeclarator.getInitializer().orElseThrow(() -> new KiePMMLException(String.format(MISSING_VARIABLE_INITIALIZER_TEMPLATE, NORM_CONTINUOUS, toReturn))).asObjectCreationExpr();
    final StringLiteralExpr nameExpr = new StringLiteralExpr(normContinuous.getField().getValue());
    final OUTLIER_TREATMENT_METHOD outlierTreatmentMethod = OUTLIER_TREATMENT_METHOD.byName(normContinuous.getOutliers().value());
    final NameExpr outlierTreatmentMethodExpr = new NameExpr(OUTLIER_TREATMENT_METHOD.class.getName() + "." + outlierTreatmentMethod.name());
    final NodeList<Expression> arguments = new NodeList<>();
    int counter = 0;
    for (LinearNorm linearNorm : normContinuous.getLinearNorms()) {
        arguments.add(getNewKiePMMLLinearNormExpression(linearNorm, "LinearNorm-" + counter));
        // Advance the index so each generated linear norm gets a distinct name
        // ("LinearNorm-0", "LinearNorm-1", ...); previously all were named "LinearNorm-0".
        // NOTE(review): expected-source fixtures hard-coding "LinearNorm-0" must be updated.
        counter++;
    }
    final Expression mapMissingToExpr = getExpressionForObject(normContinuous.getMapMissingTo());
    objectCreationExpr.getArguments().set(0, nameExpr);
    // Argument 2 is a method call in the template; its arguments become the linear norm list.
    objectCreationExpr.getArguments().get(2).asMethodCallExpr().setArguments(arguments);
    objectCreationExpr.getArguments().set(3, outlierTreatmentMethodExpr);
    objectCreationExpr.getArguments().set(4, mapMissingToExpr);
    return toReturn;
}
use of org.dmg.pmml.LinearNorm in project drools by kiegroup.
the class KiePMMLNormContinuousFactoryTest method getNewKiePMMLLinearNormExpression.
/**
 * Checks that the expression generated for a random LinearNorm matches the expected source
 * template filled with the same name, orig and norm values.
 */
@Test
public void getNewKiePMMLLinearNormExpression() throws IOException {
    final String linearNormName = "name";
    final LinearNorm source = getRandomLinearNorm();

    final Expression actual = KiePMMLNormContinuousFactory.getNewKiePMMLLinearNormExpression(source, linearNormName);

    // The expected expression is the template file with name, orig and norm substituted in.
    final String template = getFileContent(TEST_02_SOURCE);
    final String expectedSource = String.format(template, linearNormName, source.getOrig(), source.getNorm());
    final Expression expected = JavaParserUtils.parseExpression(expectedSource);

    assertTrue(JavaParserUtils.equalsNode(expected, actual));
}
use of org.dmg.pmml.LinearNorm in project drools by kiegroup.
the class KiePMMLNormContinuousFactoryTest method getNormContinuousVariableDeclaration.
/**
 * Checks that the block generated for a random NormContinuous matches the expected source
 * template and that the generated code compiles with the required imports.
 *
 * <p>The template is filled from the first three linear norms, so the random NormContinuous
 * is expected to contain at least three of them.
 */
@Test
public void getNormContinuousVariableDeclaration() throws IOException {
    final String variableName = "variableName";
    final NormContinuous source = getRandomNormContinuous();
    final List<LinearNorm> norms = source.getLinearNorms();

    final BlockStmt actual = KiePMMLNormContinuousFactory.getNormContinuousVariableDeclaration(variableName, source);

    // Fully qualified enum constant, in the form the generated code is expected to contain.
    final OUTLIER_TREATMENT_METHOD outlierMethod = OUTLIER_TREATMENT_METHOD.byName(source.getOutliers().value());
    final String outlierString = OUTLIER_TREATMENT_METHOD.class.getName() + "." + outlierMethod.name();

    final LinearNorm first = norms.get(0);
    final LinearNorm second = norms.get(1);
    final LinearNorm third = norms.get(2);
    final String template = getFileContent(TEST_01_SOURCE);
    final String expectedSource = String.format(template,
            variableName,
            source.getField().getValue(),
            first.getOrig(), first.getNorm(),
            second.getOrig(), second.getNorm(),
            third.getOrig(), third.getNorm(),
            outlierString,
            source.getMapMissingTo());
    final Statement expected = JavaParserUtils.parseBlock(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expected, actual));

    // The generated block must also compile when given these imports.
    final List<Class<?>> imports = Arrays.asList(Arrays.class, Collections.class, KiePMMLLinearNorm.class, KiePMMLNormContinuous.class, OUTLIER_TREATMENT_METHOD.class);
    commonValidateCompilationWithImports(actual, imports);
}
Aggregations