use of org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel in project knime-core by knime.
the class PMMLPortObject method moveGlobalTransformationsToModel.
/**
 * Moves the derived fields of the global transformation dictionary into the
 * local transformations of the first model found in the PMML document.
 * <p>
 * Models are probed in this order: tree, clustering, neural network, support
 * vector machine, regression, general regression, rule set. Only the first
 * model (index 0) of the first non-empty model type receives the fields; its
 * local transformations element is created if it does not exist yet. After
 * the move the transformation dictionary holds no derived fields. If the
 * dictionary is missing or empty, or the document contains none of the
 * probed model types, the document is left untouched.
 */
public void moveGlobalTransformationsToModel() {
    final PMML pmml = m_pmmlDoc.getPMML();
    final TransformationDictionary transDict = pmml.getTransformationDictionary();
    if (transDict == null) {
        // nothing to be moved
        return;
    }
    // Fetch the array exactly once: every getDerivedFieldArray() call builds a new array.
    final DerivedField[] globalDerivedFields = transDict.getDerivedFieldArray();
    if (globalDerivedFields == null || globalDerivedFields.length == 0) {
        // nothing to be moved
        return;
    }

    // The model types share no common base type, so each one needs its own branch.
    // sizeOf...Array() is used for the presence test so that no array copy is
    // created merely to read its length.
    LocalTransformations localTrans = null;
    if (pmml.sizeOfTreeModelArray() > 0) {
        final TreeModel model = pmml.getTreeModelArray(0);
        localTrans = model.getLocalTransformations();
        if (localTrans == null) {
            localTrans = model.addNewLocalTransformations();
        }
    } else if (pmml.sizeOfClusteringModelArray() > 0) {
        final ClusteringModel model = pmml.getClusteringModelArray(0);
        localTrans = model.getLocalTransformations();
        if (localTrans == null) {
            localTrans = model.addNewLocalTransformations();
        }
    } else if (pmml.sizeOfNeuralNetworkArray() > 0) {
        final NeuralNetwork model = pmml.getNeuralNetworkArray(0);
        localTrans = model.getLocalTransformations();
        if (localTrans == null) {
            localTrans = model.addNewLocalTransformations();
        }
    } else if (pmml.sizeOfSupportVectorMachineModelArray() > 0) {
        final SupportVectorMachineModel model = pmml.getSupportVectorMachineModelArray(0);
        localTrans = model.getLocalTransformations();
        if (localTrans == null) {
            localTrans = model.addNewLocalTransformations();
        }
    } else if (pmml.sizeOfRegressionModelArray() > 0) {
        final RegressionModel model = pmml.getRegressionModelArray(0);
        localTrans = model.getLocalTransformations();
        if (localTrans == null) {
            localTrans = model.addNewLocalTransformations();
        }
    } else if (pmml.sizeOfGeneralRegressionModelArray() > 0) {
        final GeneralRegressionModel model = pmml.getGeneralRegressionModelArray(0);
        localTrans = model.getLocalTransformations();
        if (localTrans == null) {
            localTrans = model.addNewLocalTransformations();
        }
    } else if (pmml.sizeOfRuleSetModelArray() > 0) {
        final RuleSetModel model = pmml.getRuleSetModelArray(0);
        localTrans = model.getLocalTransformations();
        if (localTrans == null) {
            localTrans = model.addNewLocalTransformations();
        }
    }

    if (localTrans != null) {
        // Combine the model's existing derived fields with the global ones.
        final DerivedField[] derivedFields =
            appendDerivedFields(localTrans.getDerivedFieldArray(), globalDerivedFields);
        localTrans.setDerivedFieldArray(derivedFields);
        // remove derived fields from TransformationDictionary
        transDict.setDerivedFieldArray(new DerivedField[0]);
    }
    // else do nothing as no model exists yet
}
use of org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel in project knime-core by knime.
the class PMMLModelWrapper method getSegmentContent.
/**
 * Wraps the model contained in a segment.
 * <p>
 * The segment is probed for the supported model types one after another
 * (tree, regression, general regression, clustering, naive Bayes, neural
 * network, rule set, support vector machine); the first model found is
 * wrapped and returned.
 *
 * @param s The segment
 * @return a wrapper around the segment's model, or {@code null} if the
 *         segment contains none of the probed model types
 */
public static PMMLModelWrapper getSegmentContent(final Segment s) {
    final TreeModel tree = s.getTreeModel();
    if (tree != null) {
        return new PMMLTreeModelWrapper(tree);
    }

    final RegressionModel regression = s.getRegressionModel();
    if (regression != null) {
        return new PMMLRegressionModelWrapper(regression);
    }

    final GeneralRegressionModel generalRegression = s.getGeneralRegressionModel();
    if (generalRegression != null) {
        return new PMMLGeneralRegressionModelWrapper(generalRegression);
    }

    final ClusteringModel clustering = s.getClusteringModel();
    if (clustering != null) {
        return new PMMLClusteringModelWrapper(clustering);
    }

    final NaiveBayesModel naiveBayes = s.getNaiveBayesModel();
    if (naiveBayes != null) {
        return new PMMLNaiveBayesModelWrapper(naiveBayes);
    }

    final NeuralNetwork neuralNetwork = s.getNeuralNetwork();
    if (neuralNetwork != null) {
        return new PMMLNeuralNetworkWrapper(neuralNetwork);
    }

    final RuleSetModel ruleSet = s.getRuleSetModel();
    if (ruleSet != null) {
        return new PMMLRuleSetModelWrapper(ruleSet);
    }

    final SupportVectorMachineModel svm = s.getSupportVectorMachineModel();
    if (svm != null) {
        return new PMMLSupportVectorMachineModelWrapper(svm);
    }

    // none of the supported model types is present in this segment
    return null;
}
use of org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel in project knime-core by knime.
the class PMMLGeneralRegressionTranslator method initializeFrom.
/**
 * {@inheritDoc}
 * <p>
 * Reads the first general regression model of the document into
 * {@code m_content}: model type, function name, algorithm and model name,
 * target reference category, optional offset value, and the parameter,
 * factor and covariate lists as well as the PP, PCov and Param matrices.
 * Absent or empty lists and matrices are stored as empty arrays.
 *
 * @throws IllegalArgumentException if the document holds no general
 *             regression model or the model sets {@code cumulativeLink}
 */
@Override
public void initializeFrom(final PMMLDocument pmmlDoc) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);

    final List<GeneralRegressionModel> regressionModels = pmmlDoc.getPMML().getGeneralRegressionModelList();
    if (regressionModels.isEmpty()) {
        throw new IllegalArgumentException("No general regression model provided.");
    }
    if (regressionModels.size() > 1) {
        LOGGER.warn("Multiple general regression models found. Only the first model is considered.");
    }
    final GeneralRegressionModel reg = regressionModels.get(0);

    // model type and mining function are translated to their KNIME counterparts
    final PMMLGeneralRegressionContent.ModelType knimeModelType = getKNIMERegModelType(reg.getModelType());
    m_content.setModelType(knimeModelType);
    final FunctionName knimeFunctionName = getKNIMEFunctionName(reg.getFunctionName());
    m_content.setFunctionName(knimeFunctionName);

    m_content.setAlgorithmName(reg.getAlgorithmName());
    m_content.setModelName(reg.getModelName());
    if (reg.getCumulativeLink() != null) {
        throw new IllegalArgumentException("The attribute \"cumulativeLink\" is currently not supported.");
    }
    m_content.setTargetReferenceCategory(reg.getTargetReferenceCategory());
    if (reg.isSetOffsetValue()) {
        m_content.setOffsetValue(reg.getOffsetValue());
    }

    if (reg.getLocalTransformations() != null && reg.getLocalTransformations().getDerivedFieldList() != null) {
        updateVectorLengthsBasedOnDerivedFields(reg.getLocalTransformations().getDerivedFieldList());
    }
    // Disabled alternative that reads the vector lengths from a mining schema extension:
    // final Stream<String> vectorLengthsAsJsonAsString = reg.getMiningSchema().getExtensionList().stream()
    // .filter(e -> e.getExtender().equals(EXTENDER) && e.getName().equals(VECTOR_COLUMNS_WITH_LENGTH)).map(v -> v.getValue());
    // vectorLengthsAsJsonAsString
    // .forEachOrdered(jsonAsString -> m_content.updateVectorLengths(
    // Json.createReader(new StringReader(jsonAsString)).readObject().entrySet().stream().collect(
    // Collectors.toMap(Entry::getKey, entry -> ((JsonNumber)entry.getValue()).intValueExact()))));

    // parameter list: names go through the name mapper, labels are taken as they are
    final ParameterList pmmlParameters = reg.getParameterList();
    if (pmmlParameters == null || pmmlParameters.sizeOfParameterArray() <= 0) {
        m_content.setParameterList(new PMMLParameter[0]);
    } else {
        final List<Parameter> source = pmmlParameters.getParameterList();
        final PMMLParameter[] parameters = new PMMLParameter[source.size()];
        int idx = 0;
        for (final Parameter parameter : source) {
            final String name = m_nameMapper.getColumnName(parameter.getName());
            final String label = parameter.getLabel();
            parameters[idx++] = label == null ? new PMMLParameter(name) : new PMMLParameter(name, label);
        }
        m_content.setParameterList(parameters);
    }

    // factor list
    final FactorList pmmlFactors = reg.getFactorList();
    if (pmmlFactors == null || pmmlFactors.sizeOfPredictorArray() <= 0) {
        m_content.setFactorList(new PMMLPredictor[0]);
    } else {
        final List<Predictor> source = pmmlFactors.getPredictorList();
        final PMMLPredictor[] factors = new PMMLPredictor[source.size()];
        int idx = 0;
        for (final Predictor factor : source) {
            factors[idx++] = new PMMLPredictor(m_nameMapper.getColumnName(factor.getName()));
        }
        m_content.setFactorList(factors);
    }

    // covariate list
    final CovariateList pmmlCovariates = reg.getCovariateList();
    if (pmmlCovariates == null || pmmlCovariates.sizeOfPredictorArray() <= 0) {
        m_content.setCovariateList(new PMMLPredictor[0]);
    } else {
        final List<Predictor> source = pmmlCovariates.getPredictorList();
        final PMMLPredictor[] covariates = new PMMLPredictor[source.size()];
        int idx = 0;
        for (final Predictor covariate : source) {
            covariates[idx++] = new PMMLPredictor(m_nameMapper.getColumnName(covariate.getName()));
        }
        m_content.setCovariateList(covariates);
    }

    // PPMatrix: only the predictor name is run through the name mapper
    final PPMatrix pmmlPPMatrix = reg.getPPMatrix();
    if (pmmlPPMatrix == null || pmmlPPMatrix.sizeOfPPCellArray() <= 0) {
        m_content.setPPMatrix(new PMMLPPCell[0]);
    } else {
        final List<PPCell> source = pmmlPPMatrix.getPPCellList();
        final PMMLPPCell[] ppCells = new PMMLPPCell[source.size()];
        int idx = 0;
        for (final PPCell cell : source) {
            final String predictorName = m_nameMapper.getColumnName(cell.getPredictorName());
            ppCells[idx++] =
                new PMMLPPCell(cell.getValue(), predictorName, cell.getParameterName(), cell.getTargetCategory());
        }
        m_content.setPPMatrix(ppCells);
    }

    // PCovMatrix
    final PCovMatrix pmmlPCovMatrix = reg.getPCovMatrix();
    if (pmmlPCovMatrix == null || pmmlPCovMatrix.sizeOfPCovCellArray() <= 0) {
        m_content.setPCovMatrix(new PMMLPCovCell[0]);
    } else {
        final List<PCovCell> source = pmmlPCovMatrix.getPCovCellList();
        final PMMLPCovCell[] pCovCells = new PMMLPCovCell[source.size()];
        int idx = 0;
        for (final PCovCell cell : source) {
            pCovCells[idx++] = new PMMLPCovCell(cell.getPRow(), cell.getPCol(), cell.getTRow(), cell.getTCol(),
                cell.getValue(), cell.getTargetCategory());
        }
        m_content.setPCovMatrix(pCovCells);
    }

    // ParamMatrix: the degrees of freedom are optional and only passed on when present
    final ParamMatrix pmmlParamMatrix = reg.getParamMatrix();
    if (pmmlParamMatrix == null || pmmlParamMatrix.sizeOfPCellArray() <= 0) {
        m_content.setParamMatrix(new PMMLPCell[0]);
    } else {
        final List<PCell> source = pmmlParamMatrix.getPCellList();
        final PMMLPCell[] pCells = new PMMLPCell[source.size()];
        int idx = 0;
        for (final PCell cell : source) {
            final double beta = cell.getBeta();
            final BigInteger df = cell.getDf();
            if (df == null) {
                pCells[idx++] = new PMMLPCell(cell.getParameterName(), beta, cell.getTargetCategory());
            } else {
                pCells[idx++] =
                    new PMMLPCell(cell.getParameterName(), beta, df.intValue(), cell.getTargetCategory());
            }
        }
        m_content.setParamMatrix(pCells);
    }
}
use of org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel in project knime-core by knime.
the class PMMLUtils method getFirstMiningSchema.
/**
 * Looks up the mining schema of the first model of the requested type.
 *
 * @param pmmlDoc the PMML document to extract the mining schema from
 * @param type the schema type identifying the kind of model
 * @return the mining schema of the first model of the given type, or
 *         {@code null} if the document contains no model of that type or
 *         the type is not one of the model types handled here
 */
public static MiningSchema getFirstMiningSchema(final PMMLDocument pmmlDoc, final SchemaType type) {
    final Map<PMMLModelType, Integer> modelCounts = getNumberOfModels(pmmlDoc);
    final PMMLModelType requested = PMMLModelType.getType(type);
    if (!modelCounts.containsKey(requested)) {
        return null;
    }

    final PMML pmml = pmmlDoc.getPMML();
    /*
     * The PMML model classes have no common base class, so each model type
     * has to be addressed through its own accessor to reach its mining
     * schema.
     */
    if (AssociationModel.type.equals(type)) {
        return pmml.getAssociationModelArray(0).getMiningSchema();
    }
    if (ClusteringModel.type.equals(type)) {
        return pmml.getClusteringModelArray(0).getMiningSchema();
    }
    if (GeneralRegressionModel.type.equals(type)) {
        return pmml.getGeneralRegressionModelArray(0).getMiningSchema();
    }
    if (MiningModel.type.equals(type)) {
        return pmml.getMiningModelArray(0).getMiningSchema();
    }
    if (NaiveBayesModel.type.equals(type)) {
        return pmml.getNaiveBayesModelArray(0).getMiningSchema();
    }
    if (NeuralNetwork.type.equals(type)) {
        return pmml.getNeuralNetworkArray(0).getMiningSchema();
    }
    if (RegressionModel.type.equals(type)) {
        return pmml.getRegressionModelArray(0).getMiningSchema();
    }
    if (RuleSetModel.type.equals(type)) {
        return pmml.getRuleSetModelArray(0).getMiningSchema();
    }
    if (SequenceModel.type.equals(type)) {
        return pmml.getSequenceModelArray(0).getMiningSchema();
    }
    if (SupportVectorMachineModel.type.equals(type)) {
        return pmml.getSupportVectorMachineModelArray(0).getMiningSchema();
    }
    if (TextModel.type.equals(type)) {
        return pmml.getTextModelArray(0).getMiningSchema();
    }
    if (TimeSeriesModel.type.equals(type)) {
        return pmml.getTimeSeriesModelArray(0).getMiningSchema();
    }
    if (TreeModel.type.equals(type)) {
        return pmml.getTreeModelArray(0).getMiningSchema();
    }
    // schema type does not correspond to any of the handled model types
    return null;
}
use of org.dmg.pmml.GeneralRegressionModelDocument.GeneralRegressionModel in project knime-core by knime.
the class PMMLGeneralRegressionTranslator method exportTo.
/**
 * {@inheritDoc}
 * <p>
 * Appends a new general regression model to the document and fills it from
 * {@code m_content}. For every vector column registered in the content's
 * vector lengths that is present in the spec, one derived field per vector
 * position is written to the model's local transformations; each field
 * applies the {@code substring} function to the vector column.
 *
 * @return {@code GeneralRegressionModel.type}
 * @throws UnsupportedOperationException if a vector column is neither a
 *             bit vector nor a byte vector (directly or via its
 *             {@code realType} property)
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    final GeneralRegressionModel reg = pmmlDoc.getPMML().addNewGeneralRegressionModel();
    // Collects column name -> vector length; only consumed by the disabled extension code below.
    final JsonObjectBuilder jsonBuilder = Json.createObjectBuilder();

    if (!m_content.getVectorLengths().isEmpty()) {
        final LocalTransformations localTransformations = reg.addNewLocalTransformations();
        for (final Entry<? extends String, ? extends Integer> entry : m_content.getVectorLengths().entrySet()) {
            final DataColumnSpec columnSpec = spec.getDataTableSpec().getColumnSpec(entry.getKey());
            if (columnSpec != null) {
                final DataType type = columnSpec.getType();
                final DataColumnProperties props = columnSpec.getProperties();
                // A column counts as a vector either by its type or, for string columns, by its "realType".
                final boolean bitVector = type.isCompatible(BitVectorValue.class)
                    || (type.isCompatible(StringValue.class) && props.containsProperty("realType")
                        && "BitVector".equals(props.getProperty("realType")));
                final boolean byteVector = type.isCompatible(ByteVectorValue.class)
                    || (type.isCompatible(StringValue.class) && props.containsProperty("realType")
                        && "ByteVector".equals(props.getProperty("realType")));
                if (!byteVector && !bitVector) {
                    throw new UnsupportedOperationException(
                        "Not supported type: " + type + " for column: " + columnSpec);
                }
                // byte vector settings take precedence when both flags are set
                final String lengthAsString = byteVector ? "3" : "1";
                final int width = byteVector ? 4 : 1;

                for (int pos = 0; pos < entry.getValue().intValue(); ++pos) {
                    final DerivedField derivedField = localTransformations.addNewDerivedField();
                    derivedField.setOptype(OPTYPE.CONTINUOUS);
                    derivedField.setDataType(DATATYPE.INTEGER);
                    final String fieldName = entry.getKey() + "[" + pos + "]";
                    derivedField.setName(fieldName);

                    final Apply apply = derivedField.addNewApply();
                    apply.setFunction("substring");
                    apply.addNewFieldRef().setField(entry.getKey());

                    // substring start: bit vectors count down from the vector length,
                    // everything else advances in steps of width starting at 1
                    final Constant from = apply.addNewConstant();
                    from.setDataType(DATATYPE.INTEGER);
                    final String start;
                    if (bitVector) {
                        start = Long.toString(entry.getValue().longValue() - pos);
                    } else {
                        start = Long.toString(pos * width + 1L);
                    }
                    from.setStringValue(start);

                    // substring length
                    final Constant length = apply.addNewConstant();
                    length.setDataType(DATATYPE.INTEGER);
                    length.setStringValue(lengthAsString);
                }
            }
            jsonBuilder.add(entry.getKey(), entry.getValue().intValue());
        }
    }
    // PMMLPortObjectSpecCreator newSpecCreator = new PMMLPortObjectSpecCreator(spec);
    // newSpecCreator.addPreprocColNames(m_content.getVectorLengths().entrySet().stream()
    // .flatMap(
    // e -> IntStream.iterate(0, o -> o + 1).limit(e.getValue()).mapToObj(i -> e.getKey() + "[" + i + "]"))
    // .collect(Collectors.toList()));
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, reg);
    // if (!m_content.getVectorLengths().isEmpty()) {
    // Extension miningExtension = reg.getMiningSchema().addNewExtension();
    // miningExtension.setExtender(EXTENDER);
    // miningExtension.setName(VECTOR_COLUMNS_WITH_LENGTH);
    // miningExtension.setValue(jsonBuilder.build().toString());
    // }

    // model attributes; the optional string attributes are only written when non-empty
    reg.setModelType(getPMMLRegModelType(m_content.getModelType()));
    reg.setFunctionName(getPMMLMiningFunction(m_content.getFunctionName()));
    final String algorithm = m_content.getAlgorithmName();
    if (algorithm != null && !algorithm.isEmpty()) {
        reg.setAlgorithmName(algorithm);
    }
    final String name = m_content.getModelName();
    if (name != null && !name.isEmpty()) {
        reg.setModelName(name);
    }
    final String referenceCategory = m_content.getTargetReferenceCategory();
    if (referenceCategory != null && !referenceCategory.isEmpty()) {
        reg.setTargetReferenceCategory(referenceCategory);
    }
    if (m_content.getOffsetValue() != null) {
        reg.setOffsetValue(m_content.getOffsetValue());
    }

    // parameter list: the label, if any, is mapped to its derived field name
    final ParameterList pmmlParameters = reg.addNewParameterList();
    for (final PMMLParameter source : m_content.getParameterList()) {
        final Parameter target = pmmlParameters.addNewParameter();
        target.setName(source.getName());
        final String label = source.getLabel();
        if (label != null) {
            target.setLabel(m_nameMapper.getDerivedFieldName(label));
        }
    }

    // factor list
    final FactorList pmmlFactors = reg.addNewFactorList();
    for (final PMMLPredictor factor : m_content.getFactorList()) {
        final Predictor target = pmmlFactors.addNewPredictor();
        target.setName(m_nameMapper.getDerivedFieldName(factor.getName()));
    }

    // covariate list
    final CovariateList pmmlCovariates = reg.addNewCovariateList();
    for (final PMMLPredictor covariate : m_content.getCovariateList()) {
        final Predictor target = pmmlCovariates.addNewPredictor();
        target.setName(m_nameMapper.getDerivedFieldName(covariate.getName()));
    }

    // PPMatrix
    final PPMatrix pmmlPPMatrix = reg.addNewPPMatrix();
    for (final PMMLPPCell source : m_content.getPPMatrix()) {
        final PPCell target = pmmlPPMatrix.addNewPPCell();
        target.setValue(source.getValue());
        target.setPredictorName(m_nameMapper.getDerivedFieldName(source.getPredictorName()));
        target.setParameterName(source.getParameterName());
        final String category = source.getTargetCategory();
        if (category != null && !category.isEmpty()) {
            target.setTargetCategory(category);
        }
    }

    // PCovMatrix: the element is only created when there are cells to write
    if (m_content.getPCovMatrix().length > 0) {
        final PCovMatrix pmmlPCovMatrix = reg.addNewPCovMatrix();
        for (final PMMLPCovCell source : m_content.getPCovMatrix()) {
            final PCovCell target = pmmlPCovMatrix.addNewPCovCell();
            target.setPRow(source.getPRow());
            target.setPCol(source.getPCol());
            final String tCol = source.getTCol();
            final String tRow = source.getTRow();
            // tRow/tCol are written as a pair as soon as either of them is set
            if (tRow != null || tCol != null) {
                target.setTRow(tRow);
                target.setTCol(tCol);
            }
            target.setValue(source.getValue());
            final String category = source.getTargetCategory();
            if (category != null && !category.isEmpty()) {
                target.setTargetCategory(category);
            }
        }
    }

    // ParamMatrix
    final ParamMatrix pmmlParamMatrix = reg.addNewParamMatrix();
    for (final PMMLPCell source : m_content.getParamMatrix()) {
        final PCell target = pmmlParamMatrix.addNewPCell();
        final String category = source.getTargetCategory();
        if (category != null) {
            target.setTargetCategory(category);
        }
        target.setParameterName(source.getParameterName());
        target.setBeta(source.getBeta());
        final Integer df = source.getDf();
        if (df != null) {
            target.setDf(BigInteger.valueOf(df));
        }
    }
    return GeneralRegressionModel.type;
}
Aggregations