Use of org.dmg.pmml.MiningModelDocument.MiningModel in project knime-core by knime.
In class ClassificationGBTModelImporter, method processClassSegment:
/**
 * Reads the per-class sub-model that is nested inside the given segment.
 * <p>
 * The nested mining model must declare the {@code REGRESSION} mining function; otherwise the
 * check via {@code CheckUtils.checkArgument} fails with a message naming the expected and the
 * actual function.
 *
 * @param segment a segment of the class-level model chain holding a nested mining model
 * @return the result of {@code readSumSegmentation} applied to the nested model's segmentation
 */
private Pair<List<TreeModelRegression>, List<Map<TreeNodeSignature, Double>>> processClassSegment(final Segment segment) {
    final MiningModel classModel = segment.getMiningModel();
    final boolean isRegression = classModel.getFunctionName() == MININGFUNCTION.REGRESSION;
    CheckUtils.checkArgument(isRegression,
        "The mining function of a class segment mining model must be '%s' but was '%s'.",
        MININGFUNCTION.REGRESSION, classModel.getFunctionName());
    return readSumSegmentation(classModel.getSegmentation());
}
Use of org.dmg.pmml.MiningModelDocument.MiningModel in project knime-core by knime.
In class ClassificationGBTModelExporter, method addClassSegment:
/**
 * Adds a new segment for one class to the given model chain.
 * <p>
 * The segment receives the id {@code classIdx + 1} and a {@code True} predicate. It holds a
 * nested mining model with the {@code REGRESSION} mining function, which is then filled with
 * the mining schema, output, target and segmentation for this class.
 *
 * @param modelChain the segmentation that receives the new class segment
 * @param classIdx index of the class (the segment id is this value plus one)
 */
private void addClassSegment(final Segmentation modelChain, final int classIdx) {
    final Segment classSegment = modelChain.addNewSegment();
    final int segmentId = classIdx + 1;
    classSegment.setId(Integer.toString(segmentId));
    classSegment.addNewTrue();

    final MiningModel classModel = classSegment.addNewMiningModel();
    classModel.setFunctionName(MININGFUNCTION.REGRESSION);
    // mining schema first, then output, target and the per-class segmentation
    PMMLMiningSchemaTranslator.writeMiningSchema(getPMMLSpec(), classModel);
    addOutput(classModel, classIdx);
    addTarget(classModel);
    addSegmentation(classModel, classIdx);
}
Use of org.dmg.pmml.MiningModelDocument.MiningModel in project knime-core by knime.
In class PMMLUtils, method getFirstMiningSchema:
/**
 * Retrieves the mining schema of the first model of a specific type.
 *
 * @param pmmlDoc the PMML document to extract the mining schema from
 * @param type the schema type of the model (e.g. {@code TreeModel.type})
 * @return the mining schema of the first model of the given type, or {@code null} if the
 *         document contains no model of that type or the type is not one of the model types
 *         handled below
 */
public static MiningSchema getFirstMiningSchema(final PMMLDocument pmmlDoc, final SchemaType type) {
    final Map<PMMLModelType, Integer> modelCounts = getNumberOfModels(pmmlDoc);
    if (!modelCounts.containsKey(PMMLModelType.getType(type))) {
        return null;
    }
    final PMML pmml = pmmlDoc.getPMML();
    // The generated PMML model classes share no common base type that exposes
    // getMiningSchema(), so every supported model type needs its own typed array accessor.
    if (AssociationModel.type.equals(type)) {
        return pmml.getAssociationModelArray(0).getMiningSchema();
    }
    if (ClusteringModel.type.equals(type)) {
        return pmml.getClusteringModelArray(0).getMiningSchema();
    }
    if (GeneralRegressionModel.type.equals(type)) {
        return pmml.getGeneralRegressionModelArray(0).getMiningSchema();
    }
    if (MiningModel.type.equals(type)) {
        return pmml.getMiningModelArray(0).getMiningSchema();
    }
    if (NaiveBayesModel.type.equals(type)) {
        return pmml.getNaiveBayesModelArray(0).getMiningSchema();
    }
    if (NeuralNetwork.type.equals(type)) {
        return pmml.getNeuralNetworkArray(0).getMiningSchema();
    }
    if (RegressionModel.type.equals(type)) {
        return pmml.getRegressionModelArray(0).getMiningSchema();
    }
    if (RuleSetModel.type.equals(type)) {
        return pmml.getRuleSetModelArray(0).getMiningSchema();
    }
    if (SequenceModel.type.equals(type)) {
        return pmml.getSequenceModelArray(0).getMiningSchema();
    }
    if (SupportVectorMachineModel.type.equals(type)) {
        return pmml.getSupportVectorMachineModelArray(0).getMiningSchema();
    }
    if (TextModel.type.equals(type)) {
        return pmml.getTextModelArray(0).getMiningSchema();
    }
    if (TimeSeriesModel.type.equals(type)) {
        return pmml.getTimeSeriesModelArray(0).getMiningSchema();
    }
    if (TreeModel.type.equals(type)) {
        return pmml.getTreeModelArray(0).getMiningSchema();
    }
    // unsupported model type
    return null;
}
Use of org.dmg.pmml.MiningModelDocument.MiningModel in project knime-core by knime.
In class PMMLPortObject, method moveDerivedFields:
/**
 * Moves the content of the transformation dictionary into the local transformations of the
 * first model of the given type.
 * <p>
 * The derived fields and extensions of the document's {@code TransformationDictionary} are
 * copied into a new {@link LocalTransformations} element, which is attached to the first model
 * of {@code type}. Only when the type is supported are the dictionary's derived fields and
 * extensions cleared afterwards. For an unsupported non-null type an error is logged and the
 * dictionary is left untouched.
 *
 * @param type the type of model to move the derived fields to
 * @return the {@link LocalTransformations} element holding the copied derived fields, or an
 *         empty local transformations object if the document has no transformation dictionary.
 *         NOTE(review): for an unsupported type the returned element still holds the copied
 *         fields even though nothing was moved. Confirm callers expect this.
 */
private LocalTransformations moveDerivedFields(final SchemaType type) {
    final PMML pmml = m_pmmlDoc.getPMML();
    final TransformationDictionary dictionary = pmml.getTransformationDictionary();
    final LocalTransformations localTrans = LocalTransformations.Factory.newInstance();
    if (dictionary == null) {
        // no global transformations exist, so there is nothing to move
        return localTrans;
    }
    localTrans.setDerivedFieldArray(dictionary.getDerivedFieldArray());
    localTrans.setExtensionArray(dictionary.getExtensionArray());

    // The generated PMML model classes share no common base type that exposes
    // setLocalTransformations(), so every supported model type needs its own typed accessor.
    // NOTE(review): this presumably replaces any LocalTransformations the model already has.
    // Confirm that is intended.
    if (AssociationModel.type.equals(type)) {
        pmml.getAssociationModelArray(0).setLocalTransformations(localTrans);
    } else if (ClusteringModel.type.equals(type)) {
        pmml.getClusteringModelArray(0).setLocalTransformations(localTrans);
    } else if (GeneralRegressionModel.type.equals(type)) {
        pmml.getGeneralRegressionModelArray(0).setLocalTransformations(localTrans);
    } else if (MiningModel.type.equals(type)) {
        pmml.getMiningModelArray(0).setLocalTransformations(localTrans);
    } else if (NaiveBayesModel.type.equals(type)) {
        pmml.getNaiveBayesModelArray(0).setLocalTransformations(localTrans);
    } else if (NeuralNetwork.type.equals(type)) {
        pmml.getNeuralNetworkArray(0).setLocalTransformations(localTrans);
    } else if (RegressionModel.type.equals(type)) {
        pmml.getRegressionModelArray(0).setLocalTransformations(localTrans);
    } else if (RuleSetModel.type.equals(type)) {
        pmml.getRuleSetModelArray(0).setLocalTransformations(localTrans);
    } else if (SequenceModel.type.equals(type)) {
        pmml.getSequenceModelArray(0).setLocalTransformations(localTrans);
    } else if (SupportVectorMachineModel.type.equals(type)) {
        pmml.getSupportVectorMachineModelArray(0).setLocalTransformations(localTrans);
    } else if (TextModel.type.equals(type)) {
        pmml.getTextModelArray(0).setLocalTransformations(localTrans);
    } else if (TimeSeriesModel.type.equals(type)) {
        pmml.getTimeSeriesModelArray(0).setLocalTransformations(localTrans);
    } else if (TreeModel.type.equals(type)) {
        pmml.getTreeModelArray(0).setLocalTransformations(localTrans);
    } else {
        if (type != null) {
            LOGGER.error("Could not move TransformationDictionary to " + "unsupported model of type \"" + type + "\".");
        }
        // unsupported (or null) type: keep the TransformationDictionary as it is
        return localTrans;
    }
    // the fields now live in the model's local transformations, so remove them globally
    dictionary.setDerivedFieldArray(new DerivedField[0]);
    dictionary.setExtensionArray(new ExtensionDocument.Extension[0]);
    return localTrans;
}
Use of org.dmg.pmml.MiningModelDocument.MiningModel in project knime-core by knime.
In class AbstractGBTModelPMMLTranslator, method initializeFrom:
/**
 * {@inheritDoc}
 * <p>
 * Only documents whose header names {@code "KNIME"} as the producing application are accepted.
 * The first mining model of the document is imported as the gradient boosted trees model, and
 * the learn spec is taken from the meta data mapper built for that same model.
 *
 * @throws IllegalArgumentException if the document was not created by KNIME (including a
 *         missing header, application or application name) or contains no mining model
 */
@Override
public void initializeFrom(final PMMLDocument pmmlDoc) {
    final PMML pmml = pmmlDoc.getPMML();
    // "KNIME".equals(...) is null-safe: a missing application name is rejected with the
    // documented IllegalArgumentException instead of a NullPointerException.
    if (pmml.getHeader() == null || pmml.getHeader().getApplication() == null
        || !"KNIME".equals(pmml.getHeader().getApplication().getName())) {
        throw new IllegalArgumentException("Currently only models created with KNIME are supported.");
    }
    final List<MiningModel> mmList = pmml.getMiningModelList();
    if (mmList == null || mmList.isEmpty()) {
        throw new IllegalArgumentException("The provided PMML does not contain a Gradient Boosted Trees model.");
    }
    final MiningModel model = mmList.get(0);
    final MetaDataMapper<TreeTargetNumericColumnMetaData> metaDataMapper =
        new RegressionMetaDataMapper(pmmlDoc, getTargetFieldName(model));
    final AbstractGBTModelImporter<M> importer = createImporter(metaDataMapper);
    // import exactly the model whose target field was used to build the meta data mapper
    m_gbtModel = importer.importFromPMML(model);
    m_learnSpec = metaDataMapper.getLearnSpec();
}
Aggregations