Use of org.kie.api.pmml.PMMLRequestData in project drools by kiegroup.
From the class SimpleRegressionTest, method testRegression.
/**
 * Runs the "LinReg" model through a rule unit executor and checks that the
 * value reported for "Fld4" matches a hand-computed expected result.
 *
 * @throws Exception if any of the called helpers or the executor throws;
 *                   any such exception fails the test
 */
@Test
public void testRegression() throws Exception {
    // Named once so the request parameters and the expected-value formula
    // below are guaranteed to use the same numbers.
    final double fld1 = 0.9;
    final double fld2 = 0.3;

    RuleUnitExecutor executor = createExecutor(source1);

    PMMLRequestData request = new PMMLRequestData("123", "LinReg");
    request.addRequestParam("fld1", fld1);
    request.addRequestParam("fld2", fld2);
    request.addRequestParam("fld3", "x");
    PMML4Result resultHolder = new PMML4Result();

    List<String> possiblePackages = calculatePossiblePackageNames("LinReg");
    Class<? extends RuleUnit> unitClass = getStartingRuleUnit("RuleUnitIndicator", (InternalKnowledgeBase) kbase, possiblePackages);
    assertNotNull(unitClass);

    // First run happens before any data is inserted; its return value is not needed.
    // NOTE(review): presumably this primes the unit before the request is fed in - confirm.
    executor.run(unitClass);

    data.insert(request);
    resultData.insert(resultHolder);
    executor.run(unitClass);

    assertEquals("OK", resultHolder.getResultCode());
    assertNotNull(resultHolder.getResultValue("Fld4", null));
    Double value = resultHolder.getResultValue("Fld4", "value", Double.class).orElse(null);
    assertNotNull(value);

    // Expected value: a linear term built from the inputs, then passed through
    // the logistic function 1 / (1 + e^-x).
    // NOTE(review): the -3.0 term presumably corresponds to fld3 == "x" in the
    // model definition - confirm against source1.
    double chkVal = 0.5 + 5 * fld1 * fld1 + 2 * fld2 - 3.0 + 0.4 * fld1 * fld2;
    chkVal = 1.0 / (1.0 + Math.exp(-chkVal));
    assertEquals(chkVal, value, 1e-6);
}
Use of org.kie.api.pmml.PMMLRequestData in project drools by kiegroup.
From the class MiningSegmentTransfer, method getOutboundRequest.
/**
 * Returns the outbound request, building and caching it on the first call.
 *
 * <p>When no request has been built yet, a new one is created with this
 * transfer's correlation id, its source is set to
 * {@code "MiningSegmentTransfer:<fromSegmentId>-<toSegmentId>"}, and one
 * request parameter is added for every entry of {@code requestFromResultMap}
 * whose result value (looked up via {@code getValueFromResult}) is non-null.
 * Entries whose result value is {@code null} are skipped.
 *
 * @return the cached outbound request; the same instance on later calls
 */
public PMMLRequestData getOutboundRequest() {
    if (outboundRequest != null) {
        return outboundRequest;
    }
    outboundRequest = new PMMLRequestData(this.correlationId);
    outboundRequest.setSource("MiningSegmentTransfer:" + this.fromSegmentId + "-" + this.toSegmentId);
    // Walk key/value pairs directly instead of keySet() + get(), which did a
    // second lookup for every key.
    requestFromResultMap.forEach((requestField, resultFieldName) -> {
        Object resultFieldValue = getValueFromResult(resultFieldName);
        if (resultFieldValue != null) {
            outboundRequest.addRequestParam(requestField, resultFieldValue);
        }
    });
    return outboundRequest;
}
Aggregations