use of org.dmg.pmml.pmml_4_2.descr.PMML in project drools by kiegroup.
the class ScorecardReasonCodeTest method testReasonCodes.
/**
 * Compiles the {@code /scoremodel_reasoncodes.xls} spreadsheet and inspects the resulting
 * PMML document: the first {@code Characteristics} element found inside a {@code Scorecard}
 * must hold exactly four characteristics, and every attribute of each characteristic must
 * carry a non-null reason code. The test fails if no such element is found.
 */
@Test
public void testReasonCodes() throws Exception {
    final ScorecardCompiler compiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    final boolean compiled = compiler.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_reasoncodes.xls"));
    if (!compiled) {
        assertErrors(compiler);
    }
    final PMML pmml = compiler.getPMMLDocument();
    for (Object model : pmml.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
        if (!(model instanceof Scorecard)) {
            continue;
        }
        for (Object child : ((Scorecard) model).getExtensionsAndCharacteristicsAndMiningSchemas()) {
            if (!(child instanceof Characteristics)) {
                continue;
            }
            final Characteristics group = (Characteristics) child;
            assertEquals(4, group.getCharacteristics().size());
            for (Characteristic entry : group.getCharacteristics()) {
                for (Attribute attr : entry.getAttributes()) {
                    assertNotNull(attr.getReasonCode());
                }
            }
            // The first Characteristics element is all this test inspects.
            return;
        }
    }
    // No Scorecard with a Characteristics element was found.
    fail();
}
use of org.dmg.pmml.pmml_4_2.descr.PMML in project drools by kiegroup.
the class ScoringStrategiesTest method testScoringExtension.
/**
 * Compiles the {@code /scoremodel_scoring_strategies.xls} spreadsheet and verifies that the
 * first {@code Scorecard} in the resulting PMML document is named "Sample Score" and carries
 * a scoring-strategy extension whose value equals {@code AggregationStrategy.AGGREGATE_SCORE}.
 * The test fails if compilation is unsuccessful or no {@code Scorecard} is present.
 */
@Test
public void testScoringExtension() throws Exception {
    ScorecardCompiler scorecardCompiler = new ScorecardCompiler(INTERNAL_DECLARED_TYPES);
    if (scorecardCompiler.compileFromExcel(PMMLDocumentTest.class.getResourceAsStream("/scoremodel_scoring_strategies.xls"))) {
        PMML pmmlDocument = scorecardCompiler.getPMMLDocument();
        assertNotNull(pmmlDocument);
        String drl = scorecardCompiler.getDRL();
        assertNotNull(drl);
        for (Object serializable : pmmlDocument.getAssociationModelsAndBaselineModelsAndClusteringModels()) {
            if (serializable instanceof Scorecard) {
                Scorecard scorecard = (Scorecard) serializable;
                assertEquals("Sample Score", scorecard.getModelName());
                Extension extension = ScorecardPMMLUtils.getExtension(scorecard.getExtensionsAndCharacteristicsAndMiningSchemas(), ScorecardPMMLExtensionNames.SCORECARD_SCORING_STRATEGY);
                assertNotNull(extension);
                // Expected value goes first so a failure message labels the two values correctly.
                assertEquals(AggregationStrategy.AGGREGATE_SCORE.toString(), extension.getValue());
                return;
            }
        }
    }
    // Either compilation failed or the document contained no Scorecard.
    fail();
}
use of org.dmg.pmml.pmml_4_2.descr.PMML in project drools by kiegroup.
the class GuidedScoreCardDRLPersistence method marshal.
/**
 * Renders the given score card model as DRL text.
 *
 * <p>The model is first turned into a PMML document via {@code createPMMLDocument}, which is
 * then converted to DRL using the {@code EXTERNAL_OBJECT_MODEL} DRL type.
 *
 * @param model the score card model to convert
 * @return the DRL produced for the model
 */
public static String marshal(final ScoreCardModel model) {
    // Package statement and imports are appended by org.drools.scorecards.drl.AbstractDRLEmitter,
    // so only the rules themselves are built here.
    return new StringBuilder()
            .append(ScorecardCompiler.convertToDRL(createPMMLDocument(model), ScorecardCompiler.DrlType.EXTERNAL_OBJECT_MODEL))
            .toString();
}
use of org.dmg.pmml.pmml_4_2.descr.PMML in project drools by kiegroup.
the class PMMLGenerationTest method testNNGenration.
/**
 * Round-trips a generated neural-network PMML document.
 *
 * <p>The network is generated, serialized to a byte stream, validated against the schema
 * located at {@code compiler.SCHEMA_PATH}, and unmarshalled back through JAXB. The original
 * and the re-read document are then compared on data-dictionary size, number of neural
 * network elements, and one sample connection weight.
 */
@Test
public void testNNGenration() {
    PMML net = PMMLGeneratorUtils.generateSimpleNeuralNetwork(modelName, inputfieldNames, outputfieldNames, inputMeans, inputStds, outputMeans, outputStds, hiddenSize, weights);
    assertNotNull(net);
    ByteArrayOutputStream baos = new ByteArrayOutputStream();
    assertTrue(PMMLGeneratorUtils.streamPMML(net, baos));
    ByteArrayInputStream bais = new ByteArrayInputStream(baos.toByteArray());
    PMML4Compiler compiler = new PMML4Compiler();
    SchemaFactory sf = SchemaFactory.newInstance(XMLConstants.W3C_XML_SCHEMA_NS_URI);
    try {
        Schema schema = sf.newSchema(Thread.currentThread().getContextClassLoader().getResource(compiler.SCHEMA_PATH));
        schema.newValidator().validate(new StreamSource(bais));
    } catch (SAXException | IOException e) {
        fail(e.getMessage());
    }
    PMML net2;
    try {
        // Validation consumed the stream; rewind it before unmarshalling.
        bais.reset();
        JAXBContext ctx = JAXBContext.newInstance(PMML.class.getPackage().getName());
        net2 = (PMML) ctx.createUnmarshaller().unmarshal(bais);
    } catch (JAXBException e) {
        // Fail right here with the cause attached instead of printing the stack trace and
        // letting a later null assertion hide the real reason.
        throw new AssertionError("Unable to unmarshal the generated PMML: " + e.getMessage(), e);
    }
    assertNotNull(net2);
    assertEquals(inputfieldNames.length + outputfieldNames.length, net2.getDataDictionary().getDataFields().size());
    assertEquals(net.getDataDictionary().getDataFields().size(), net2.getDataDictionary().getDataFields().size());
    NeuralNetwork n1 = (NeuralNetwork) net.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    NeuralNetwork n2 = (NeuralNetwork) net2.getAssociationModelsAndBaselineModelsAndClusteringModels().get(0);
    assertEquals(n1.getExtensionsAndNeuralLayersAndNeuralInputs().size(), n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());
    assertEquals(6, n2.getExtensionsAndNeuralLayersAndNeuralInputs().size());
    NeuralLayer l1 = (NeuralLayer) n1.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    NeuralLayer l2 = (NeuralLayer) n2.getExtensionsAndNeuralLayersAndNeuralInputs().get(3);
    // Spot-check one connection weight: it must survive the round trip and match the source array.
    assertEquals(l1.getNeurons().get(4).getCons().get(2).getWeight(), l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
    assertEquals(weights[(inputfieldNames.length + 1) * 4 + 3], l2.getNeurons().get(4).getCons().get(2).getWeight(), 1e-9);
}
use of org.dmg.pmml.pmml_4_2.descr.PMML in project drools by kiegroup.
the class PMML4Compiler method addMissingFieldDefinition.
/**
 * Adds data-dictionary entries for mining fields of a segment's model that are not yet in
 * the dictionary.
 *
 * <p>For each mining field of {@code seg}'s model whose {@code isInDictionary()} is false,
 * the models of the other segments with a lower segment index are searched, in order, for an
 * output field of the same name. The first one that yields a usable {@code DataField}
 * (the output field's own raw data field when it has a data type, otherwise the raw data
 * field of the output's target mining field) is renamed to the mining field's name and
 * appended to the PMML data dictionary.
 *
 * @param pmml the PMML document whose data dictionary is extended
 * @param msm  the segmentation providing the candidate segments
 * @param seg  the segment whose model's mining fields are being resolved
 */
private void addMissingFieldDefinition(PMML pmml, MiningSegmentation msm, MiningSegment seg) {
    // Models that may contain the field definition: every other segment with a lower index.
    List<PMML4Model> models = msm.getMiningSegments().stream()
            .filter(s -> s != seg && s.getSegmentIndex() < seg.getSegmentIndex())
            .map(s -> s.getModel())
            .collect(Collectors.toList());
    seg.getModel().getMiningFields().stream().filter(mf -> !mf.isInDictionary()).forEach(pmf -> {
        String fldName = pmf.getName();
        for (PMML4Model mdl : models) {
            PMMLOutputField outfield = mdl.findOutputField(fldName);
            if (outfield == null) {
                continue;
            }
            PMMLMiningField target = (outfield.getTargetField() != null) ? mdl.findMiningField(outfield.getTargetField()) : null;
            DataField e = null;
            if (outfield.getRawDataField() != null && outfield.getRawDataField().getDataType() != null) {
                e = outfield.getRawDataField();
            } else if (target != null) {
                e = target.getRawDataField();
            }
            if (e == null) {
                continue;
            }
            e.setName(fldName);
            pmml.getDataDictionary().getDataFields().add(e);
            // NOTE(review): numberOfFields is presumably optional in the PMML schema and may be
            // unset; only bump the counter when the document actually declares one.
            BigInteger bi = pmml.getDataDictionary().getNumberOfFields();
            if (bi != null) {
                pmml.getDataDictionary().setNumberOfFields(bi.add(BigInteger.ONE));
            }
            // The first model that supplies a definition wins.
            break;
        }
    });
}
Aggregations