Use of org.dmg.pmml.DerivedFieldDocument.DerivedField in project knime-core by KNIME.
Class PMMLNeuralNetworkTranslator, method initiateNeuralOutputs.
/**
 * Builds {@code m_classmap}, which maps each output class value of the PMML
 * neural network to the position of its output neuron.
 *
 * <p>The class value of a {@code NeuralOutput} is taken from its derived field:
 * the value of a {@code NormDiscrete}, or the referenced field name of a
 * {@code FieldRef}. Any other expression is logged as an error and skipped.
 * While iterating, {@code m_curPercpetronID} is updated to the id of the
 * output neuron currently being processed.
 *
 * @param nnModel
 *            the PMML neural network model
 */
private void initiateNeuralOutputs(final NeuralNetwork nnModel) {
    final NeuralOutputs outputs = nnModel.getNeuralOutputs();
    m_classmap = new HashMap<DataCell, Integer>();
    for (final NeuralOutput output : outputs.getNeuralOutputArray()) {
        m_curPercpetronID = output.getOutputNeuron();
        final DerivedField field = output.getDerivedField();
        final String className;
        if (field.isSetNormDiscrete()) {
            className = field.getNormDiscrete().getValue();
        } else if (field.isSetFieldRef()) {
            className = field.getFieldRef().getField();
        } else {
            LOGGER.error("The expression is not supported in KNIME MLP.");
            continue;
        }
        // position of the output neuron in the network's output layer
        final int position = m_idPosMap.get(m_curPercpetronID);
        m_classmap.put(new StringCell(className), position);
    }
}
Use of org.dmg.pmml.DerivedFieldDocument.DerivedField in project knime-core by KNIME.
Class DBApplyBinnerNodeModel, method createQuery.
/**
 * Builds the SQL statement that bins the result of {@code query} according to the
 * {@code Discretize} derived fields of the given PMML model.
 *
 * @param query the SQL query whose result is binned
 * @param statementManipulator database specific helper that creates the binner statement
 * @param dataTableSpec spec of the table produced by {@code query}
 * @param pmmlPortObject PMML model holding the binning definitions as derived fields
 * @return the binner SQL statement
 * @throws InvalidSettingsException if a derived field is not a {@code Discretize} expression,
 *             if a binned column is missing from the input table or from the PMML model spec,
 *             or if its type is incompatible with the PMML model
 */
private String createQuery(final String query, final StatementManipulator statementManipulator,
    final DataTableSpec dataTableSpec, final PMMLPortObject pmmlPortObject) throws InvalidSettingsException {
    DBBinnerMaps maps = DBAutoBinner.intoBinnerMaps(pmmlPortObject, dataTableSpec);
    DerivedField[] derivedFields = pmmlPortObject.getDerivedFields();
    final DataTableSpec pmmlInputSpec = pmmlPortObject.getSpec().getDataTableSpec();
    String[] includeCols = new String[derivedFields.length];
    // iterate over the local array instead of calling getDerivedFields() again on every iteration
    for (int i = 0; i < derivedFields.length; i++) {
        final DerivedField derivedField = derivedFields[i];
        if (!derivedField.isSetDiscretize()) {
            // getDiscretize() would return null and cause an NPE below
            throw new InvalidSettingsException("Derived field '" + derivedField.getName()
                + "' is not a Discretize expression; only binning is supported");
        }
        String fieldName = derivedField.getDiscretize().getField();
        final DataColumnSpec colSpec = dataTableSpec.getColumnSpec(fieldName);
        if (colSpec == null) {
            throw new InvalidSettingsException("Column '" + fieldName + "' not found in input table");
        }
        DataColumnSpec pmmlInputColSpec = pmmlInputSpec.getColumnSpec(fieldName);
        if (pmmlInputColSpec == null) {
            // an assert alone is skipped when assertions are disabled and leads to an NPE
            throw new InvalidSettingsException(
                "Column '" + fieldName + "' from derived fields not found in PMML model spec");
        }
        DataType knimeType = pmmlInputColSpec.getType();
        if (!colSpec.getType().isCompatible(knimeType.getPreferredValueClass())) {
            throw new InvalidSettingsException("Data type of column '" + fieldName
                + "' is not compatible with PMML model: expected '" + knimeType + "' but is '"
                + colSpec.getType() + "'");
        }
        includeCols[i] = fieldName;
    }
    String[] allColumns = dataTableSpec.getColumnNames();
    String[] excludeCols = filter(includeCols, allColumns);
    return statementManipulator.getBinnerStatement(query, includeCols, excludeCols, maps.getBoundariesMap(),
        maps.getBoundariesOpenMap(), maps.getNamingMap(), maps.getAppendMap());
}
Use of org.dmg.pmml.DerivedFieldDocument.DerivedField in project knime-core by KNIME.
Class PMMLMapValuesTranslator, method createDerivedFields.
/**
 * Creates the PMML derived field that describes the configured value mapping as a
 * {@code MapValues} expression backed by an inline table.
 *
 * <p>If input and output column are the same, the input column name is used as display
 * name and a new derived field name is created via the mapper; otherwise the output column
 * name is used as the derived field name. Default value and "map missing to" are only
 * written when they are not missing.
 *
 * @return an array containing the single created derived field
 */
private DerivedField[] createDerivedFields() {
    // Dummy element names for the inline table columns are used instead of the KNIME
    // column names, since those could contain characters that are not allowed in XML.
    final QName xmlIn = new QName("http://www.dmg.org/PMML-4_0", "in");
    final QName xmlOut = new QName("http://www.dmg.org/PMML-4_0", "out");

    DerivedField df = DerivedField.Factory.newInstance();
    df.setExtensionArray(createSummaryExtension());
    /* The field name must be retrieved before creating a new derived
     * name for this derived field as the map only contains the
     * current mapping. */
    String fieldName = m_mapper.getDerivedFieldName(m_config.getInColumn());
    if (m_config.getInColumn().equals(m_config.getOutColumn())) {
        String name = m_config.getInColumn();
        df.setDisplayName(name);
        df.setName(m_mapper.createDerivedFieldName(name));
    } else {
        df.setName(m_config.getOutColumn());
    }
    df.setOptype(m_config.getOpType());
    df.setDataType(m_config.getOutDataType());

    MapValues mapValues = df.addNewMapValues();
    // the element in the InlineTable representing the output column
    mapValues.setOutputColumn(xmlOut.getLocalPart());
    mapValues.setDataType(m_config.getOutDataType());
    if (!m_config.getDefaultValue().isMissing()) {
        mapValues.setDefaultValue(m_config.getDefaultValue().toString());
    }
    if (!m_config.getMapMissingTo().isMissing()) {
        mapValues.setMapMissingTo(m_config.getMapMissingTo().toString());
    }

    // the mapping of input field <-> element in the InlineTable
    FieldColumnPair fieldColPair = mapValues.addNewFieldColumnPair();
    fieldColPair.setField(fieldName);
    fieldColPair.setColumn(xmlIn.getLocalPart());

    InlineTable table = mapValues.addNewInlineTable();
    for (Entry<DataCell, ? extends DataCell> entry : m_config.getEntries().entrySet()) {
        Row row = table.addNewRow();
        XmlCursor cursor = row.newCursor();
        try {
            cursor.toNextToken();
            cursor.insertElementWithText(xmlIn, entry.getKey().toString());
            cursor.insertElementWithText(xmlOut, entry.getValue().toString());
        } finally {
            // release the cursor even if an insert fails
            cursor.dispose();
        }
    }
    return new DerivedField[] {df};
}
Use of org.dmg.pmml.DerivedFieldDocument.DerivedField in project knime-core by KNIME.
Class PMMLNormalizeTranslator, method initializeFrom.
/**
 * {@inheritDoc}
 *
 * <p>Reads every derived field holding a {@code NormContinuous} expression and records the
 * corresponding linear transformation {@code y = scale * x + translation} together with the
 * KNIME column it applies to. Other derived fields are skipped. The extension array of the
 * first derived field is parsed as well.
 *
 * @return the indices of the consumed (NormContinuous) derived fields; an empty list if
 *         {@code derivedFields} is {@code null}
 * @throws IllegalArgumentException if a NormContinuous does not contain exactly two
 *             LinearNorm elements, or if both LinearNorm elements have the same orig value
 */
@Override
public List<Integer> initializeFrom(final DerivedField[] derivedFields) {
    if (derivedFields == null) {
        return Collections.emptyList();
    }
    m_mapper = new DerivedFieldMapper(derivedFields);
    int num = derivedFields.length;
    List<Integer> consumed = new ArrayList<Integer>(num);
    if (num > 0) {
        parseExtensionArray(derivedFields[0].getExtensionArray());
    }
    for (int i = 0; i < num; i++) {
        DerivedField df = derivedFields[i];
        if (!df.isSetNormContinuous()) {
            // only reading norm continuous, other entries are skipped
            continue;
        }
        consumed.add(i);
        NormContinuous normContinuous = df.getNormContinuous();
        LinearNorm[] norms = normContinuous.getLinearNormArray();
        if (norms.length > 2) {
            throw new IllegalArgumentException("Only two LinearNorm " + "elements are supported per NormContinuous");
        }
        if (norms.length < 2) {
            // fewer than two points do not define a linear function; the missing point would
            // otherwise silently be treated as (0, 0)
            throw new IllegalArgumentException("NormContinuous requires two LinearNorm elements but "
                + norms.length + " found");
        }
        double orig0 = norms[0].getOrig();
        double orig1 = norms[1].getOrig();
        double norm0 = norms[0].getNorm();
        double norm1 = norms[1].getNorm();
        if (orig0 == orig1) {
            // would divide by zero when computing the scale
            throw new IllegalArgumentException(
                "The orig values of the two LinearNorm elements must differ, both are " + orig0);
        }
        double scale = (norm1 - norm0) / (orig1 - orig0);
        m_scales.add(scale);
        m_translations.add(norm0 - scale * orig0);

        /* The display name contains the name of the column in KNIME that corresponds to the
         * derived field in PMML. This is necessary if derived fields are defined on other
         * derived fields and the columns in KNIME are replaced with the preprocessed values.
         * In this case KNIME has to know the original names (e.g. A) while PMML references
         * A', A'' etc. */
        String displayName = df.getDisplayName();
        if (displayName != null) {
            m_fields.add(displayName);
        } else {
            m_fields.add(m_mapper.getColumnName(normContinuous.getField()));
        }
    }
    return consumed;
}
Use of org.dmg.pmml.DerivedFieldDocument.DerivedField in project knime-core by KNIME.
Class PMMLGeneralRegressionTranslator, method exportTo.
/**
 * {@inheritDoc}
 *
 * <p>Appends a {@code GeneralRegressionModel} to {@code pmmlDoc}. It contains the local
 * transformations for vector columns (if any), the mining schema, the model attributes, the
 * parameter, factor and covariate lists, and the PP, PCov and Param matrices of the current
 * content. Parameter labels and predictor names are translated with a
 * {@link DerivedFieldMapper} built from {@code pmmlDoc}.
 *
 * @return {@code GeneralRegressionModel.type}
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    GeneralRegressionModel reg = pmmlDoc.getPMML().addNewGeneralRegressionModel();
    if (!m_content.getVectorLengths().isEmpty()) {
        addVectorElementDerivedFields(reg.addNewLocalTransformations(), spec);
    }
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, reg);

    reg.setModelType(getPMMLRegModelType(m_content.getModelType()));
    reg.setFunctionName(getPMMLMiningFunction(m_content.getFunctionName()));
    String algorithmName = m_content.getAlgorithmName();
    if (algorithmName != null && !algorithmName.isEmpty()) {
        reg.setAlgorithmName(algorithmName);
    }
    String modelName = m_content.getModelName();
    if (modelName != null && !modelName.isEmpty()) {
        reg.setModelName(modelName);
    }
    String targetReferenceCategory = m_content.getTargetReferenceCategory();
    if (targetReferenceCategory != null && !targetReferenceCategory.isEmpty()) {
        reg.setTargetReferenceCategory(targetReferenceCategory);
    }
    if (m_content.getOffsetValue() != null) {
        reg.setOffsetValue(m_content.getOffsetValue());
    }

    // add parameter list
    ParameterList paramList = reg.addNewParameterList();
    for (PMMLParameter p : m_content.getParameterList()) {
        Parameter param = paramList.addNewParameter();
        param.setName(p.getName());
        String label = p.getLabel();
        if (label != null) {
            param.setLabel(m_nameMapper.getDerivedFieldName(label));
        }
    }

    // add factor list
    FactorList factorList = reg.addNewFactorList();
    for (PMMLPredictor p : m_content.getFactorList()) {
        Predictor predictor = factorList.addNewPredictor();
        predictor.setName(m_nameMapper.getDerivedFieldName(p.getName()));
    }

    // add covariate list
    CovariateList covariateList = reg.addNewCovariateList();
    for (PMMLPredictor p : m_content.getCovariateList()) {
        Predictor predictor = covariateList.addNewPredictor();
        predictor.setName(m_nameMapper.getDerivedFieldName(p.getName()));
    }

    // add PPMatrix
    PPMatrix ppMatrix = reg.addNewPPMatrix();
    for (PMMLPPCell p : m_content.getPPMatrix()) {
        PPCell cell = ppMatrix.addNewPPCell();
        cell.setValue(p.getValue());
        cell.setPredictorName(m_nameMapper.getDerivedFieldName(p.getPredictorName()));
        cell.setParameterName(p.getParameterName());
        String targetCategory = p.getTargetCategory();
        if (targetCategory != null && !targetCategory.isEmpty()) {
            cell.setTargetCategory(targetCategory);
        }
    }

    // add CovMatrix
    if (m_content.getPCovMatrix().length > 0) {
        PCovMatrix pCovMatrix = reg.addNewPCovMatrix();
        for (PMMLPCovCell p : m_content.getPCovMatrix()) {
            PCovCell covCell = pCovMatrix.addNewPCovCell();
            covCell.setPRow(p.getPRow());
            covCell.setPCol(p.getPCol());
            // write each optional target attribute only when present instead of passing
            // null to the setter when just one of them is set
            String tRow = p.getTRow();
            if (tRow != null) {
                covCell.setTRow(tRow);
            }
            String tCol = p.getTCol();
            if (tCol != null) {
                covCell.setTCol(tCol);
            }
            covCell.setValue(p.getValue());
            String targetCategory = p.getTargetCategory();
            if (targetCategory != null && !targetCategory.isEmpty()) {
                covCell.setTargetCategory(targetCategory);
            }
        }
    }

    // add ParamMatrix
    ParamMatrix paramMatrix = reg.addNewParamMatrix();
    for (PMMLPCell p : m_content.getParamMatrix()) {
        PCell pCell = paramMatrix.addNewPCell();
        String targetCategory = p.getTargetCategory();
        if (targetCategory != null) {
            pCell.setTargetCategory(targetCategory);
        }
        pCell.setParameterName(p.getParameterName());
        pCell.setBeta(p.getBeta());
        Integer df = p.getDf();
        if (df != null) {
            pCell.setDf(BigInteger.valueOf(df));
        }
    }
    return GeneralRegressionModel.type;
}

/**
 * Adds one integer derived field named {@code column[i]} for each element of every vector
 * column listed in the content's vector lengths. Each field extracts its element from the
 * column's string value with the PMML {@code substring} function. Columns not present in
 * {@code spec} are skipped.
 *
 * @param localTransformations the local transformations to add the derived fields to
 * @param spec the spec used to look up the vector columns
 * @throws UnsupportedOperationException if a listed column is neither a bit nor a byte vector
 */
private void addVectorElementDerivedFields(final LocalTransformations localTransformations,
    final PMMLPortObjectSpec spec) {
    for (final Entry<? extends String, ? extends Integer> entry : m_content.getVectorLengths().entrySet()) {
        DataColumnSpec columnSpec = spec.getDataTableSpec().getColumnSpec(entry.getKey());
        if (columnSpec == null) {
            continue;
        }
        final DataType type = columnSpec.getType();
        final DataColumnProperties props = columnSpec.getProperties();
        final boolean bitVector =
            type.isCompatible(BitVectorValue.class) || isStringWithRealType(type, props, "BitVector");
        final boolean byteVector =
            type.isCompatible(ByteVectorValue.class) || isStringWithRealType(type, props, "ByteVector");
        final String lengthAsString;
        final int width;
        if (byteVector) {
            lengthAsString = "3";
            width = 4;
        } else if (bitVector) {
            lengthAsString = "1";
            width = 1;
        } else {
            throw new UnsupportedOperationException("Not supported type: " + type + " for column: " + columnSpec);
        }
        final int vectorLength = entry.getValue().intValue();
        for (int i = 0; i < vectorLength; ++i) {
            final DerivedField derivedField = localTransformations.addNewDerivedField();
            derivedField.setOptype(OPTYPE.CONTINUOUS);
            derivedField.setDataType(DATATYPE.INTEGER);
            derivedField.setName(entry.getKey() + "[" + i + "]");
            Apply apply = derivedField.addNewApply();
            apply.setFunction("substring");
            apply.addNewFieldRef().setField(entry.getKey());
            // 1-based start position: bit vectors are addressed from the end of the string,
            // byte vector elements occupy fixed-width slots
            Constant from = apply.addNewConstant();
            from.setDataType(DATATYPE.INTEGER);
            from.setStringValue(
                bitVector ? Long.toString(entry.getValue().longValue() - i) : Long.toString(i * width + 1L));
            Constant length = apply.addNewConstant();
            length.setDataType(DATATYPE.INTEGER);
            length.setStringValue(lengthAsString);
        }
    }
}

/**
 * Checks whether a string column declares the given {@code realType} column property.
 *
 * @return {@code true} if {@code type} is string compatible and {@code props} contains a
 *         {@code realType} property equal to {@code realType}
 */
private static boolean isStringWithRealType(final DataType type, final DataColumnProperties props,
    final String realType) {
    return type.isCompatible(StringValue.class) && props.containsProperty("realType")
        && realType.equals(props.getProperty("realType"));
}
Aggregations