use of org.dmg.pmml.RuleSetModelDocument.RuleSetModel in project knime-core by knime.
the class PMMLModelWrapper method getSegmentContent.
/**
 * Wraps the model stored in a segment.
 * <p>
 * The segment is probed for a tree, regression, general regression, clustering, naive Bayes,
 * neural network, rule set and support vector machine model, in that order. The first model
 * that is present gets wrapped.
 *
 * @param s The segment
 * @return a wrapper around the first model found in the segment, or {@code null} if the segment
 *         holds none of the model kinds checked here
 */
public static PMMLModelWrapper getSegmentContent(final Segment s) {
    final TreeModel tree = s.getTreeModel();
    if (tree != null) {
        return new PMMLTreeModelWrapper(tree);
    }

    final RegressionModel regression = s.getRegressionModel();
    if (regression != null) {
        return new PMMLRegressionModelWrapper(regression);
    }

    final GeneralRegressionModel generalRegression = s.getGeneralRegressionModel();
    if (generalRegression != null) {
        return new PMMLGeneralRegressionModelWrapper(generalRegression);
    }

    final ClusteringModel clustering = s.getClusteringModel();
    if (clustering != null) {
        return new PMMLClusteringModelWrapper(clustering);
    }

    final NaiveBayesModel naiveBayes = s.getNaiveBayesModel();
    if (naiveBayes != null) {
        return new PMMLNaiveBayesModelWrapper(naiveBayes);
    }

    final NeuralNetwork neuralNetwork = s.getNeuralNetwork();
    if (neuralNetwork != null) {
        return new PMMLNeuralNetworkWrapper(neuralNetwork);
    }

    final RuleSetModel ruleSet = s.getRuleSetModel();
    if (ruleSet != null) {
        return new PMMLRuleSetModelWrapper(ruleSet);
    }

    // Last candidate: no match here means the segment has no supported model.
    final SupportVectorMachineModel svm = s.getSupportVectorMachineModel();
    return svm == null ? null : new PMMLSupportVectorMachineModelWrapper(svm);
}
use of org.dmg.pmml.RuleSetModelDocument.RuleSetModel in project knime-core by knime.
the class PMMLUtils method getFirstMiningSchema.
/**
 * Looks up the mining schema of the first model of the requested type.
 *
 * @param pmmlDoc the PMML document to extract the mining schema from
 * @param type the type of the model
 * @return the mining schema of the first model of the given type, or {@code null} if the
 *         document contains no model of that type or the type is not one of those handled here
 */
public static MiningSchema getFirstMiningSchema(final PMMLDocument pmmlDoc, final SchemaType type) {
    final Map<PMMLModelType, Integer> modelCounts = getNumberOfModels(pmmlDoc);
    if (!modelCounts.containsKey(PMMLModelType.getType(type))) {
        return null;
    }

    final PMML pmml = pmmlDoc.getPMML();

    // The PMML model classes share no common base class, so every model type needs its own
    // branch to reach the type-specific accessor for the first model and its mining schema.
    if (AssociationModel.type.equals(type)) {
        return pmml.getAssociationModelArray(0).getMiningSchema();
    }
    if (ClusteringModel.type.equals(type)) {
        return pmml.getClusteringModelArray(0).getMiningSchema();
    }
    if (GeneralRegressionModel.type.equals(type)) {
        return pmml.getGeneralRegressionModelArray(0).getMiningSchema();
    }
    if (MiningModel.type.equals(type)) {
        return pmml.getMiningModelArray(0).getMiningSchema();
    }
    if (NaiveBayesModel.type.equals(type)) {
        return pmml.getNaiveBayesModelArray(0).getMiningSchema();
    }
    if (NeuralNetwork.type.equals(type)) {
        return pmml.getNeuralNetworkArray(0).getMiningSchema();
    }
    if (RegressionModel.type.equals(type)) {
        return pmml.getRegressionModelArray(0).getMiningSchema();
    }
    if (RuleSetModel.type.equals(type)) {
        return pmml.getRuleSetModelArray(0).getMiningSchema();
    }
    if (SequenceModel.type.equals(type)) {
        return pmml.getSequenceModelArray(0).getMiningSchema();
    }
    if (SupportVectorMachineModel.type.equals(type)) {
        return pmml.getSupportVectorMachineModelArray(0).getMiningSchema();
    }
    if (TextModel.type.equals(type)) {
        return pmml.getTextModelArray(0).getMiningSchema();
    }
    if (TimeSeriesModel.type.equals(type)) {
        return pmml.getTimeSeriesModelArray(0).getMiningSchema();
    }
    if (TreeModel.type.equals(type)) {
        return pmml.getTreeModelArray(0).getMiningSchema();
    }
    return null;
}
use of org.dmg.pmml.RuleSetModelDocument.RuleSetModel in project knime-core by knime.
the class PMMLPortObject method moveDerivedFields.
/**
 * Moves the content of the transformation dictionary to local transformations.
 * <p>
 * The derived fields and extensions of the document's transformation dictionary are copied into
 * a new {@link LocalTransformations} element, which is then set on the first model of the given
 * type. Only when such a model type is handled here are the derived fields and extensions
 * cleared from the transformation dictionary afterwards; for an unhandled type the dictionary is
 * left as it is.
 *
 * @param type the type of model to move the derived fields to
 * @return the {@link LocalTransformations} element containing the copied derived fields and
 *         extensions, or an empty local transformation object if the document has no
 *         transformation dictionary
 */
private LocalTransformations moveDerivedFields(final SchemaType type) {
    final PMML pmml = m_pmmlDoc.getPMML();
    final TransformationDictionary dictionary = pmml.getTransformationDictionary();
    final LocalTransformations localTrans = LocalTransformations.Factory.newInstance();
    if (dictionary == null) {
        // no dictionary, hence nothing to move
        return localTrans;
    }

    localTrans.setDerivedFieldArray(dictionary.getDerivedFieldArray());
    localTrans.setExtensionArray(dictionary.getExtensionArray());

    // The PMML model classes share no common base class, so every model type needs its own
    // branch to reach the type-specific accessor and set the local transformations.
    if (AssociationModel.type.equals(type)) {
        pmml.getAssociationModelArray(0).setLocalTransformations(localTrans);
    } else if (ClusteringModel.type.equals(type)) {
        pmml.getClusteringModelArray(0).setLocalTransformations(localTrans);
    } else if (GeneralRegressionModel.type.equals(type)) {
        pmml.getGeneralRegressionModelArray(0).setLocalTransformations(localTrans);
    } else if (MiningModel.type.equals(type)) {
        pmml.getMiningModelArray(0).setLocalTransformations(localTrans);
    } else if (NaiveBayesModel.type.equals(type)) {
        pmml.getNaiveBayesModelArray(0).setLocalTransformations(localTrans);
    } else if (NeuralNetwork.type.equals(type)) {
        pmml.getNeuralNetworkArray(0).setLocalTransformations(localTrans);
    } else if (RegressionModel.type.equals(type)) {
        pmml.getRegressionModelArray(0).setLocalTransformations(localTrans);
    } else if (RuleSetModel.type.equals(type)) {
        pmml.getRuleSetModelArray(0).setLocalTransformations(localTrans);
    } else if (SequenceModel.type.equals(type)) {
        pmml.getSequenceModelArray(0).setLocalTransformations(localTrans);
    } else if (SupportVectorMachineModel.type.equals(type)) {
        pmml.getSupportVectorMachineModelArray(0).setLocalTransformations(localTrans);
    } else if (TextModel.type.equals(type)) {
        pmml.getTextModelArray(0).setLocalTransformations(localTrans);
    } else if (TimeSeriesModel.type.equals(type)) {
        pmml.getTimeSeriesModelArray(0).setLocalTransformations(localTrans);
    } else if (TreeModel.type.equals(type)) {
        pmml.getTreeModelArray(0).setLocalTransformations(localTrans);
    } else {
        // Unhandled type: nothing received the local transformations, so the dictionary must
        // keep its content. A null type is tolerated silently.
        if (type != null) {
            LOGGER.error("Could not move TransformationDictionary to " + "unsupported model of type \"" + type + "\".");
        }
        return localTrans;
    }

    // The content now lives on the model; remove it from the TransformationDictionary.
    dictionary.setDerivedFieldArray(new DerivedField[0]);
    dictionary.setExtensionArray(new ExtensionDocument.Extension[0]);
    return localTrans;
}
use of org.dmg.pmml.RuleSetModelDocument.RuleSetModel in project knime-core by knime.
the class RuleEngine2PortsNodeModel method computeRearrangerWithPMML.
/**
 * Builds a PMML rule set model from the rows of the rules input and creates the column
 * rearranger that applies this model.
 * <p>
 * Every row polled from {@code rules} is turned into one simple rule of the rule set, except rows
 * whose rule text is recognized as a comment, which are skipped. As a side effect a copy of the
 * created PMML port object is stored in {@code m_copy}.
 *
 * @param spec the table spec handed to the rule parser and used as the basis of the rearranger
 * @param rules the input that supplies the rule rows
 * @param flowVars the flow variables handed to the rule parser
 * @param ruleIdx index of the cell holding the rule text; the cell must be a non-missing string
 * @param outcomeIdx index of the cell whose string form is appended to the rule text after
 *            {@code " => "}; negative to use the rule text as it is
 * @param confidenceIdx index of the cell holding the rule confidence; negative to set none. A
 *            missing or non-double cell is ignored
 * @param weightIdx index of the cell holding the rule weight; negative to set none. For a missing
 *            or non-double cell the default weight from the settings is used, if there is one
 * @param validationIdx passed through to the creation of the rearranger
 * @param outputColumnName name of the column that receives the rule outcome
 * @return the column rearranger paired with the created PMML port object
 * @throws InterruptedException presumably when polling the rules input is interrupted (the cause
 *             is not visible in this method)
 * @throws InvalidSettingsException if a rule cannot be parsed; presumably also when one of the
 *             settings checks on the rule rows fails
 */
private Pair<ColumnRearranger, PortObject> computeRearrangerWithPMML(final DataTableSpec spec, final RowInput rules, final Map<String, FlowVariable> flowVars, final int ruleIdx, final int outcomeIdx, final int confidenceIdx, final int weightIdx, final int validationIdx, final String outputColumnName) throws InterruptedException, InvalidSettingsException {
    final PMMLDocument doc = PMMLDocument.Factory.newInstance();
    final PMML pmmlRoot = doc.addNewPMML();
    final RuleSetModel ruleSetModel = pmmlRoot.addNewRuleSetModel();
    final RuleSet ruleSet = ruleSetModel.addNewRuleSet();
    final List<DataType> outcomeTypes = new ArrayList<>();
    final PMMLRuleParser parser = new PMMLRuleParser(spec, flowVars);

    int lineNo = 0;
    for (DataRow ruleRow = rules.poll(); ruleRow != null; ruleRow = rules.poll()) {
        ++lineNo;
        final DataCell rule = ruleRow.getCell(ruleIdx);
        CheckUtils.checkSetting(!rule.isMissing(), "Missing rule in row: " + ruleRow.getKey());
        if (!(rule instanceof StringValue)) {
            CheckUtils.checkSetting(false, "Wrong type (" + rule.getType() + ") of rule: " + rule + "\nin row: " + ruleRow.getKey());
            continue;
        }

        // Line breaks inside a rule are flattened to single spaces before parsing.
        String ruleText = ((StringValue) rule).getStringValue().replaceAll("[\r\n]+", " ");
        if (RuleSupport.isComment(ruleText)) {
            continue;
        }
        if (outcomeIdx >= 0) {
            ruleText += " => " + m_settings.asStringFailForMissing(ruleRow.getCell(outcomeIdx));
        }

        final ParseState state = new ParseState(ruleText);
        try {
            // condition part
            final PMMLPredicate condition = parser.parseBooleanExpression(state);
            final SimpleRule simpleRule = ruleSet.addNewSimpleRule();
            setCondition(simpleRule, condition);

            // outcome part, separated from the condition by "=>"
            state.skipWS();
            state.consumeText("=>");
            state.skipWS();
            final Expression outcome = parser.parseOutcomeOperand(state, null);
            simpleRule.setScore(outcome.toString());

            if (confidenceIdx >= 0) {
                final DataCell confidenceCell = ruleRow.getCell(confidenceIdx);
                if (!confidenceCell.isMissing() && confidenceCell instanceof DoubleValue) {
                    simpleRule.setConfidence(((DoubleValue) confidenceCell).getDoubleValue());
                }
            }

            if (weightIdx >= 0) {
                final DataCell weightCell = ruleRow.getCell(weightIdx);
                if (!weightCell.isMissing() && weightCell instanceof DoubleValue) {
                    simpleRule.setWeight(((DoubleValue) weightCell).getDoubleValue());
                } else if (m_settings.isHasDefaultWeight()) {
                    // no usable weight in the row, fall back to the configured default
                    simpleRule.setWeight(m_settings.getDefaultWeight());
                }
            }

            CheckUtils.checkSetting(outcome.isConstant(), "Outcome is not constant in line " + lineNo + " (" + ruleRow.getKey() + ") for rule: " + rule);
            outcomeTypes.add(outcome.getOutputType());
        } catch (ParseException e) {
            final ParseException withContext = Util.addContext(e, ruleText, lineNo);
            throw new InvalidSettingsException("Wrong rule in line: " + ruleRow.getKey() + "\n" + withContext.getMessage(), withContext);
        }
    }

    // This rearranger is only used to derive table specs; its cell factory never yields a value.
    final ColumnRearranger dummy = new ColumnRearranger(spec);
    if (!m_settings.isReplaceColumn()) {
        final DataType outputType = RuleEngineNodeModel.computeOutputType(outcomeTypes, computeOutcomeType(rules.getDataTableSpec()), true, m_settings.isDisallowLongOutputForCompatibility());
        dummy.append(new SingleCellFactory(new DataColumnSpecCreator(outputColumnName, outputType).createSpec()) {
            @Override
            public DataCell getCell(final DataRow row) {
                return null;
            }
        });
    }

    final PMMLPortObject pmml = createPMMLPortObject(doc, ruleSetModel, ruleSet, parser, dummy.createSpec());
    final PortObject portObject = pmml;
    m_copy = copy(pmml);

    String confidenceColumn = m_settings.getPredictionConfidenceColumn();
    if (confidenceColumn == null || confidenceColumn.isEmpty()) {
        confidenceColumn = RuleEngine2PortsSettings.DEFAULT_PREDICTION_CONFIDENCE_COLUMN;
    }

    final ColumnRearranger rearranger = PMMLRuleSetPredictorNodeModel.createRearranger(pmml, spec, m_settings.isReplaceColumn(), outputColumnName, m_settings.isComputeConfidence(), DataTableSpec.getUniqueColumnName(dummy.createSpec(), confidenceColumn), validationIdx);
    return Pair.create(rearranger, portObject);
}
use of org.dmg.pmml.RuleSetModelDocument.RuleSetModel in project knime-core by knime.
the class PMMLRuleEditorNodeModel method createRearrangerAndPMMLModel.
/**
 * Creates the column rearranger for the given table spec together with a PMML port object that
 * holds the corresponding rule set model.
 * <p>
 * The rule set model is created in a fresh PMML document as a classification model. Its mining
 * schema and the document's data dictionary are written from the spec of the new port object,
 * and the port object is validated before it is returned.
 *
 * @param spec the table spec handed to the rule parser and to the creation of the rearranger
 * @return the rearranger bundled with the PMML port object
 * @throws ParseException presumably when a rule cannot be parsed while the rearranger is created
 *             (the cause is not visible in this method)
 * @throws InvalidSettingsException presumably when the rules or settings are invalid (the cause
 *             is not visible in this method)
 */
private RearrangerAndPMMLModel createRearrangerAndPMMLModel(final DataTableSpec spec) throws ParseException, InvalidSettingsException {
    final PMMLDocument pmmlDoc = PMMLDocument.Factory.newInstance();
    final PMML root = pmmlDoc.addNewPMML();
    final RuleSetModel ruleSetModel = root.addNewRuleSetModel();
    final RuleSet ruleSet = ruleSetModel.addNewRuleSet();

    final PMMLRuleParser parser = new PMMLRuleParser(spec, getAvailableInputFlowVariables());
    final ColumnRearranger rearranger = createRearranger(spec, ruleSet, parser);
    final PMMLPortObject portObject = new PMMLPortObject(createPMMLPortObjectSpec(rearranger.createSpec(), parser.getUsedColumns()));

    final PMMLRuleTranslator translator = new PMMLRuleTranslator();
    ruleSetModel.setFunctionName(MININGFUNCTION.CLASSIFICATION);
    ruleSet.setDefaultConfidence(defaultConfidenceValue());

    // Complete the document before the translator reads it.
    PMMLMiningSchemaTranslator.writeMiningSchema(portObject.getSpec(), ruleSetModel);
    new PMMLDataDictionaryTranslator().exportTo(pmmlDoc, portObject.getSpec());

    translator.initializeFrom(pmmlDoc);
    portObject.addModelTranslater(translator);
    portObject.validate();
    return new RearrangerAndPMMLModel(rearranger, portObject);
}
Aggregations