Use of `org.dmg.pmml.LinearNorm` in the drools project by kiegroup.
Class `KiePMMLNormContinuousFactoryTest`, method `getNormContinuousVariableDeclaration`.
/**
 * Verifies that {@code KiePMMLNormContinuousFactory.getNormContinuousVariableDeclaration} emits the
 * same block as the {@code TEST_01_SOURCE} template once that template is filled with the values of
 * a random {@link NormContinuous}, and that the emitted block compiles with the expected imports.
 * <p>
 * The template is filled with the first three linear norms, so this test assumes
 * {@code getRandomNormContinuous()} yields at least three of them (TODO confirm).
 *
 * @throws IOException if the template file cannot be read
 */
@Test
public void getNormContinuousVariableDeclaration() throws IOException {
    final String variableName = "variableName";
    final NormContinuous source = getRandomNormContinuous();
    final List<LinearNorm> norms = source.getLinearNorms();

    // Code under test.
    final BlockStmt actual = KiePMMLNormContinuousFactory.getNormContinuousVariableDeclaration(variableName, source);

    // Fully-qualified reference to the enum constant the generated code is expected to use.
    final OUTLIER_TREATMENT_METHOD outlierMethod = OUTLIER_TREATMENT_METHOD.byName(source.getOutliers().value());
    final String outlierReference = OUTLIER_TREATMENT_METHOD.class.getName() + "." + outlierMethod.name();

    final String template = getFileContent(TEST_01_SOURCE);
    final LinearNorm first = norms.get(0);
    final LinearNorm second = norms.get(1);
    final LinearNorm third = norms.get(2);
    final String expectedSource = String.format(template,
            variableName,
            source.getField().getValue(),
            first.getOrig(), first.getNorm(),
            second.getOrig(), second.getNorm(),
            third.getOrig(), third.getNorm(),
            outlierReference,
            source.getMapMissingTo());
    final Statement expected = JavaParserUtils.parseBlock(expectedSource);
    assertTrue(JavaParserUtils.equalsNode(expected, actual));

    // The generated block must compile when these classes are imported.
    final List<Class<?>> expectedImports = Arrays.asList(Arrays.class,
            Collections.class,
            KiePMMLLinearNorm.class,
            KiePMMLNormContinuous.class,
            OUTLIER_TREATMENT_METHOD.class);
    commonValidateCompilationWithImports(actual, expectedImports);
}
Use of `org.dmg.pmml.LinearNorm` in the drools project by kiegroup.
Class `PMMLModelTestUtils`, method `getRandomLinearNorm`.
/**
 * Builds a {@link LinearNorm} with random {@code orig} and {@code norm} values.
 * <p>
 * Each value is a random int in {@code [0, 100)} scaled by one tenth, i.e. one of
 * {@code 0.0, 0.1, ..., 9.9}.
 *
 * @return a new {@code LinearNorm} with random one-decimal {@code orig} and {@code norm}
 */
public static LinearNorm getRandomLinearNorm() {
    Random random = new Random();
    // Divide by a double literal: "nextInt(100) / 10" is integer division and would
    // truncate to a whole number before the widening to double, dropping the decimal digit.
    double orig = random.nextInt(100) / 10.0;
    double norm = random.nextInt(100) / 10.0;
    return new LinearNorm(orig, norm);
}
Use of `org.dmg.pmml.LinearNorm` in the shifu project by ShifuML.
Class `WoeZscoreLocalTransformCreator`, method `createCategoricalDerivedField`.
/**
 * Create the {@link DerivedField}s for a categorical variable.
 * <p>
 * Two fields are returned, in this order:
 * <ol>
 * <li>the first derived field produced by the superclass for {@code NormType.WOE};</li>
 * <li>a continuous field that linearly normalizes that WOE field, mapping
 * {@code mean - stdDev * cutoff} to {@code -cutoff} and {@code mean + stdDev * cutoff} to
 * {@code cutoff}. Outliers use {@code AS_EXTREME_VALUES}, which caps them, and missing input
 * maps to {@code 0.0}.</li>
 * </ol>
 *
 * @param config - ColumnConfig for categorical variable
 * @param cutoff - cutoff for normalization; bounds the normalized output to [-cutoff, cutoff]
 * @param normType - normalization type; only used to build the name of the second (normalized) derived field
 * @return list holding the WOE derived field followed by the normalized derived field
 */
@Override
protected List<DerivedField> createCategoricalDerivedField(ColumnConfig config, double cutoff, ModelNormalizeConf.NormType normType) {
    List<DerivedField> derivedFields = new ArrayList<DerivedField>();

    // The WOE field is always built with NormType.WOE, regardless of the requested normType.
    DerivedField woeField = super.createCategoricalDerivedField(config, cutoff, ModelNormalizeConf.NormType.WOE).get(0);
    derivedFields.add(woeField);

    double[] meanAndStdDev = Normalizer.calculateWoeMeanAndStdDev(config, isWeightedNorm);
    double mean = meanAndStdDev[0];
    double stdDev = meanAndStdDev[1];

    // Two-point linear norm over the WOE value; together with AS_EXTREME_VALUES this caps the
    // normalized output to [-cutoff, cutoff].
    // NOTE(review): if stdDev is 0 both points share the same orig, which PMML may reject as
    // non-ascending LinearNorm values - confirm whether calculateWoeMeanAndStdDev can return 0.
    LinearNorm from = new LinearNorm().setOrig(mean - stdDev * cutoff).setNorm(-cutoff);
    LinearNorm to = new LinearNorm().setOrig(mean + stdDev * cutoff).setNorm(cutoff);

    NormContinuous normContinuous = new NormContinuous();
    normContinuous.setField(FieldName.create(woeField.getName().getValue()));
    normContinuous.addLinearNorms(from, to);
    normContinuous.setMapMissingTo(0.0);
    normContinuous.setOutliers(OutlierTreatmentMethod.AS_EXTREME_VALUES);

    // The normalized field's name combines the simple column name with a suffix that
    // genPmmlColumnName derives from normType.
    FieldName normalizedName = FieldName.create(genPmmlColumnName(NormalUtils.getSimpleColumnName(config.getColumnName()), normType));
    derivedFields.add(new DerivedField(OpType.CONTINUOUS, DataType.DOUBLE).setName(normalizedName).setExpression(normContinuous));
    return derivedFields;
}
Aggregations