use of org.dmg.pmml.ApplyDocument.Apply in project knime-core by knime.
the class PMMLMany2OneTranslator method createDerivedField.
/**
 * Builds the categorical string-typed derived field named after the appended column.
 *
 * <p>For every source column one {@code if} Apply is created: the first one directly below the
 * derived field, each following one inside the {@code if} created for the previous column. Every
 * {@code if} receives an {@code equal} Apply comparing the column with either the {@code max} /
 * {@code min} over all source columns (Maximum / Minimum method) or the constant {@code "1"}
 * (any other method), followed by a constant holding the column name. The {@code if} of the last
 * column additionally receives the constant {@code "missing"}.
 *
 * @return the newly created derived field; it contains no Apply when there are no source columns
 */
private DerivedField createDerivedField() {
    final DerivedField result = DerivedField.Factory.newInstance();
    result.setName(m_appendedCol);
    result.setDataType(DATATYPE.STRING);
    result.setOptype(OPTYPE.CATEGORICAL);

    // The include method does not change while iterating, so decide once how to compare.
    final boolean compareWithExtremum =
        m_method == IncludeMethod.Maximum || m_method == IncludeMethod.Minimum;
    final String extremumFunction = IncludeMethod.Maximum == m_method ? "max" : "min";

    Apply enclosingIf = null;
    for (String sourceCol : m_sourceCols) {
        // First column hangs below the derived field, later ones below the previous "if".
        final Apply conditional =
            enclosingIf == null ? result.addNewApply() : enclosingIf.addNewApply();
        conditional.setFunction("if");

        final Apply comparison = conditional.addNewApply();
        comparison.setFunction("equal");
        comparison.addNewFieldRef().setField(sourceCol);
        if (compareWithExtremum) {
            final Apply extremum = comparison.addNewApply();
            extremum.setFunction(extremumFunction);
            for (String operand : m_sourceCols) {
                extremum.addNewFieldRef().setField(operand);
            }
        } else {
            // Remaining include method (Binary): compare against the literal "1".
            comparison.addNewConstant().setStringValue("1");
        }

        conditional.addNewConstant().setStringValue(sourceCol);
        enclosingIf = conditional;
    }

    if (enclosingIf != null) {
        // Fallback value attached to the "if" of the last source column.
        enclosingIf.addNewConstant().setStringValue("missing");
    }
    return result;
}
use of org.dmg.pmml.ApplyDocument.Apply in project knime-core by knime.
the class DataColumnSpecFilterPMMLNodeModel method createPMMLOut.
/**
 * Creates the outgoing PMML port object with all excluded columns removed from the data
 * dictionary.
 *
 * <p>If no PMML is connected, a port object holding only the filtered spec is returned.
 * Otherwise the incoming document is parsed, its models and transformation dictionary are
 * inspected for references to excluded columns (each hit becomes one line of the node's warning
 * message), and the excluded data fields are dropped from the data dictionary.
 *
 * @param pmmlIn the incoming PMML port object, may be {@code null}
 * @param outSpec the table spec of the filtered output table
 * @param res the filter result deciding which columns are excluded
 * @return the PMML port object for the output port
 * @throws XmlException if the incoming PMML document cannot be parsed
 * @throws IllegalArgumentException if the port object cannot be created from the filtered
 *             document; a dedicated message is used when no column is included at all
 */
private PMMLPortObject createPMMLOut(final PMMLPortObject pmmlIn, final DataTableSpec outSpec, final FilterResult res) throws XmlException {
    if (pmmlIn == null) {
        return new PMMLPortObject(createPMMLSpec(null, outSpec, res));
    }

    PMMLDocument pmmldoc;
    try (LockedSupplier<Document> supplier = pmmlIn.getPMMLValue().getDocumentSupplier()) {
        pmmldoc = PMMLDocument.Factory.parse(supplier.get());
    }

    StringBuilder warningBuffer = new StringBuilder();

    // Inspect models to check if they use any excluded columns
    for (PMMLModelWrapper model : PMMLModelWrapper.getModelListFromPMMLDocument(pmmldoc)) {
        MiningSchema ms = model.getMiningSchema();
        for (MiningField mf : ms.getMiningFieldList()) {
            if (isExcluded(mf.getName(), res)) {
                appendWarningLine(warningBuffer,
                    model.getModelType().name() + " uses excluded column " + mf.getName());
            }
        }
    }

    PMML pmml = pmmldoc.getPMML();
    // Now check the transformations if they exist
    List<String> warningFields = collectExcludedTransformationFields(pmml, res);

    // Apply filter to spec; iterate backwards so removals do not shift pending indices
    DataDictionary dict = pmml.getDataDictionary();
    List<DataField> fields = dict.getDataFieldList();
    int numFields = 0;
    for (int i = fields.size() - 1; i >= 0; i--) {
        if (isExcluded(fields.get(i).getName(), res)) {
            dict.removeDataField(i);
        } else {
            numFields++;
        }
    }
    dict.setNumberOfFields(BigInteger.valueOf(numFields));
    pmml.setDataDictionary(dict);
    pmmldoc.setPMML(pmml);

    // generate warnings and set as warning message
    for (String s : warningFields) {
        appendWarningLine(warningBuffer, "Transformation dictionary uses excluded column " + s);
    }
    if (warningBuffer.length() > 0) {
        setWarningMessage(warningBuffer.toString().trim());
    }

    try {
        return new PMMLPortObject(createPMMLSpec(pmmlIn.getSpec(), outSpec, res), pmmldoc);
    } catch (IllegalArgumentException e) {
        if (res.getIncludes().length == 0) {
            throw new IllegalArgumentException("Excluding all columns produces invalid PMML", e);
        }
        throw e;
    }
}

/** Appends a message to the buffer, separating it from earlier content with a newline. */
private void appendWarningLine(final StringBuilder buffer, final String message) {
    if (buffer.length() != 0) {
        buffer.append("\n");
    }
    buffer.append(message);
}

/**
 * Collects the excluded columns referenced by the derived fields of the transformation
 * dictionary; returns an empty list when the document has no transformation dictionary.
 */
private List<String> collectExcludedTransformationFields(final PMML pmml, final FilterResult res) {
    List<String> excludedFields = new ArrayList<String>();
    if (pmml.getTransformationDictionary() == null) {
        return excludedFields;
    }
    for (DerivedField df : pmml.getTransformationDictionary().getDerivedFieldList()) {
        FieldRef fr = df.getFieldRef();
        if (fr != null && isExcluded(fr.getField(), res)) {
            excludedFields.add(fr.getField());
        }
        Aggregate a = df.getAggregate();
        if (a != null && isExcluded(a.getField(), res)) {
            excludedFields.add(a.getField());
        }
        Apply ap = df.getApply();
        if (ap != null) {
            // Report at most one excluded column per top-level Apply.
            String excluded = findExcludedField(ap, res);
            if (excluded != null) {
                excludedFields.add(excluded);
            }
        }
        Discretize d = df.getDiscretize();
        if (d != null && isExcluded(d.getField(), res)) {
            excludedFields.add(d.getField());
        }
        MapValues mv = df.getMapValues();
        if (mv != null) {
            for (FieldColumnPair fcp : mv.getFieldColumnPairList()) {
                if (isExcluded(fcp.getField(), res)) {
                    excludedFields.add(fcp.getField());
                }
            }
        }
        NormContinuous nc = df.getNormContinuous();
        if (nc != null && isExcluded(nc.getField(), res)) {
            excludedFields.add(nc.getField());
        }
        NormDiscrete nd = df.getNormDiscrete();
        if (nd != null && isExcluded(nd.getField(), res)) {
            excludedFields.add(nd.getField());
        }
    }
    return excludedFields;
}

/**
 * Returns the first excluded column referenced by the given Apply, checking its direct field
 * references first and then descending into nested Apply elements; {@code null} if none is found.
 */
private String findExcludedField(final Apply apply, final FilterResult res) {
    for (FieldRef fieldRef : apply.getFieldRefList()) {
        if (isExcluded(fieldRef.getField(), res)) {
            return fieldRef.getField();
        }
    }
    // Function calls can be nested (e.g. if(equal(FieldRef, ...))), so the column reference
    // may only show up further down the tree.
    for (Apply nested : apply.getApplyList()) {
        String excluded = findExcludedField(nested, res);
        if (excluded != null) {
            return excluded;
        }
    }
    return null;
}
use of org.dmg.pmml.ApplyDocument.Apply in project knime-core by knime.
the class PMMLMany2OneTranslator method createDerivedField.
/**
 * Creates the derived field for the appended column as a categorical string field.
 *
 * <p>One {@code if} Apply is added per source column. The first is a child of the derived field,
 * every later one a child of the {@code if} built for the preceding column. Each {@code if} gets
 * an {@code equal} Apply whose first argument is a reference to the column and whose second
 * argument is the {@code max} or {@code min} over all source columns when the include method is
 * Maximum or Minimum, and the constant {@code "1"} otherwise. After that a constant with the
 * column name is added. The last {@code if} finally gets the constant {@code "missing"}.
 *
 * @return the derived field, without any Apply if the list of source columns is empty
 */
private DerivedField createDerivedField() {
    final DerivedField appendedField = DerivedField.Factory.newInstance();
    appendedField.setName(m_appendedCol);
    appendedField.setDataType(DATATYPE.STRING);
    appendedField.setOptype(OPTYPE.CATEGORICAL);

    final boolean isMaximum = IncludeMethod.Maximum == m_method;
    final boolean isExtremumMethod = isMaximum || m_method == IncludeMethod.Minimum;

    Apply previousIf = null;
    for (String column : m_sourceCols) {
        Apply currentIf;
        if (previousIf != null) {
            currentIf = previousIf.addNewApply();
        } else {
            currentIf = appendedField.addNewApply();
        }
        currentIf.setFunction("if");

        // Condition: column == max/min(all source columns) or column == "1"
        Apply condition = currentIf.addNewApply();
        condition.setFunction("equal");
        condition.addNewFieldRef().setField(column);
        if (!isExtremumMethod) {
            condition.addNewConstant().setStringValue("1");
        } else {
            Apply aggregate = condition.addNewApply();
            aggregate.setFunction(isMaximum ? "max" : "min");
            for (String candidate : m_sourceCols) {
                aggregate.addNewFieldRef().setField(candidate);
            }
        }

        // Value associated with this column: its own name
        currentIf.addNewConstant().setStringValue(column);
        previousIf = currentIf;
    }

    if (previousIf == null) {
        return appendedField;
    }
    previousIf.addNewConstant().setStringValue("missing");
    return appendedField;
}
use of org.dmg.pmml.ApplyDocument.Apply in project knime-core by knime.
the class MissingCellHandler method createValueReplacingDerivedField.
/**
 * Creates a derived field that substitutes a fixed value whenever the handled column is missing.
 * The derived field carries the name and display name of the handled column.
 *
 * @param dataType the data type of the field; STRING and BOOLEAN yield a categorical field,
 *            every other type a continuous one
 * @param value the replacement value for the field
 * @return the derived field
 */
protected DerivedField createValueReplacingDerivedField(final DATATYPE.Enum dataType, final String value) {
    final String columnName = m_col.getName();
    final boolean categorical =
        dataType == org.dmg.pmml.DATATYPE.STRING || dataType == org.dmg.pmml.DATATYPE.BOOLEAN;

    final DerivedField derived = DerivedField.Factory.newInstance();
    derived.setOptype(categorical ? org.dmg.pmml.OPTYPE.CATEGORICAL : org.dmg.pmml.OPTYPE.CONTINUOUS);

    /*
     * PMML equivalent of "if fieldVal is missing then x else fieldVal":
     *
     * <Apply function="if">
     *     <Apply function="isMissing">
     *         <FieldRef field="fieldVal"/>
     *     </Apply>
     *     <Constant dataType="___" value="x"/>
     *     <FieldRef field="fieldVal"/>
     * </Apply>
     */
    final Apply conditional = derived.addNewApply();
    conditional.setFunction(IF_FUNCTION_NAME);

    // The same reference is handed to both the isMissing test and the "if" itself.
    final FieldRef columnRef = FieldRef.Factory.newInstance();
    columnRef.setField(columnName);

    final Apply missingCheck = Apply.Factory.newInstance();
    missingCheck.setFieldRefArray(new FieldRef[] { columnRef });
    missingCheck.setFunction(IS_MISSING_FUNCTION_NAME);
    conditional.setApplyArray(new Apply[] { missingCheck });

    final Constant substitute = Constant.Factory.newInstance();
    substitute.setDataType(dataType);
    substitute.setStringValue(value);
    conditional.setConstantArray(new Constant[] { substitute });
    conditional.setFieldRefArray(new FieldRef[] { columnRef });

    derived.setDataType(dataType);
    derived.setName(columnName);
    derived.setDisplayName(columnName);
    return derived;
}
use of org.dmg.pmml.ApplyDocument.Apply in project knime-core by knime.
the class PMMLGeneralRegressionTranslator method exportTo.
/**
 * {@inheritDoc}
 *
 * <p>Adds a new GeneralRegressionModel to the document and fills it, in this order, with the
 * local transformations for vector columns (only if vector lengths are present), the mining
 * schema, the model attributes, the parameter, factor and covariate lists, the PPMatrix, the
 * PCovMatrix (only if non-empty) and the ParamMatrix.
 *
 * @throws UnsupportedOperationException if a column listed in the vector lengths is present in
 *             the spec but is neither a bit vector nor a byte vector column
 */
@Override
public SchemaType exportTo(final PMMLDocument pmmlDoc, final PMMLPortObjectSpec spec) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);
    GeneralRegressionModel reg = pmmlDoc.getPMML().addNewGeneralRegressionModel();
    if (!m_content.getVectorLengths().isEmpty()) {
        exportVectorTransformations(reg.addNewLocalTransformations(), spec);
    }
    // Vector lengths are exported only through the derived fields above; no mining schema
    // extension is written for them.
    PMMLMiningSchemaTranslator.writeMiningSchema(spec, reg);
    exportModelAttributes(reg);
    exportParameterList(reg);
    exportPredictorLists(reg);
    exportPPMatrix(reg);
    exportPCovMatrix(reg);
    exportParamMatrix(reg);
    return GeneralRegressionModel.type;
}

/**
 * Adds one integer derived field per vector position, named {@code column[i]}, that extracts
 * the position from the vector column with the {@code substring} function.
 */
private void exportVectorTransformations(final LocalTransformations localTransformations,
    final PMMLPortObjectSpec spec) {
    for (final Entry<? extends String, ? extends Integer> entry : m_content.getVectorLengths().entrySet()) {
        final DataColumnSpec columnSpec = spec.getDataTableSpec().getColumnSpec(entry.getKey());
        if (columnSpec == null) {
            // Column is not part of the spec, nothing to derive from.
            continue;
        }
        final DataType type = columnSpec.getType();
        final DataColumnProperties props = columnSpec.getProperties();
        // A vector column is either of the vector type itself or a string column that is
        // tagged with the matching "realType" property.
        final boolean bitVector = type.isCompatible(BitVectorValue.class)
            || (type.isCompatible(StringValue.class) && props.containsProperty("realType")
                && "BitVector".equals(props.getProperty("realType")));
        final boolean byteVector = type.isCompatible(ByteVectorValue.class)
            || (type.isCompatible(StringValue.class) && props.containsProperty("realType")
                && "ByteVector".equals(props.getProperty("realType")));
        final String lengthAsString;
        final int width;
        if (byteVector) {
            lengthAsString = "3";
            width = 4;
        } else if (bitVector) {
            lengthAsString = "1";
            width = 1;
        } else {
            throw new UnsupportedOperationException("Not supported type: " + type + " for column: " + columnSpec);
        }
        final int vectorLength = entry.getValue().intValue();
        for (int i = 0; i < vectorLength; ++i) {
            final DerivedField derivedField = localTransformations.addNewDerivedField();
            derivedField.setOptype(OPTYPE.CONTINUOUS);
            derivedField.setDataType(DATATYPE.INTEGER);
            derivedField.setName(entry.getKey() + "[" + i + "]");
            Apply apply = derivedField.addNewApply();
            apply.setFunction("substring");
            apply.addNewFieldRef().setField(entry.getKey());
            // First constant: start position. Bit vectors are addressed from the end of the
            // string (length - i), everything else from the front in steps of 'width'.
            Constant from = apply.addNewConstant();
            from.setDataType(DATATYPE.INTEGER);
            from.setStringValue(bitVector ? Long.toString((long)vectorLength - i) : Long.toString(i * width + 1L));
            // Second constant: number of characters to extract.
            Constant length = apply.addNewConstant();
            length.setDataType(DATATYPE.INTEGER);
            length.setStringValue(lengthAsString);
        }
    }
}

/** Sets model type and function name, plus the optional attributes that have a value. */
private void exportModelAttributes(final GeneralRegressionModel reg) {
    reg.setModelType(getPMMLRegModelType(m_content.getModelType()));
    reg.setFunctionName(getPMMLMiningFunction(m_content.getFunctionName()));
    String algorithmName = m_content.getAlgorithmName();
    if (algorithmName != null && !algorithmName.isEmpty()) {
        reg.setAlgorithmName(algorithmName);
    }
    String modelName = m_content.getModelName();
    if (modelName != null && !modelName.isEmpty()) {
        reg.setModelName(modelName);
    }
    String targetReferenceCategory = m_content.getTargetReferenceCategory();
    if (targetReferenceCategory != null && !targetReferenceCategory.isEmpty()) {
        reg.setTargetReferenceCategory(targetReferenceCategory);
    }
    if (m_content.getOffsetValue() != null) {
        reg.setOffsetValue(m_content.getOffsetValue());
    }
}

/** Adds the parameter list; labels are translated through the derived field mapper. */
private void exportParameterList(final GeneralRegressionModel reg) {
    ParameterList paramList = reg.addNewParameterList();
    for (PMMLParameter p : m_content.getParameterList()) {
        Parameter param = paramList.addNewParameter();
        param.setName(p.getName());
        String label = p.getLabel();
        if (label != null) {
            param.setLabel(m_nameMapper.getDerivedFieldName(label));
        }
    }
}

/** Adds the factor list followed by the covariate list, mapping names to derived field names. */
private void exportPredictorLists(final GeneralRegressionModel reg) {
    FactorList factorList = reg.addNewFactorList();
    for (PMMLPredictor p : m_content.getFactorList()) {
        Predictor predictor = factorList.addNewPredictor();
        predictor.setName(m_nameMapper.getDerivedFieldName(p.getName()));
    }
    CovariateList covariateList = reg.addNewCovariateList();
    for (PMMLPredictor p : m_content.getCovariateList()) {
        Predictor predictor = covariateList.addNewPredictor();
        predictor.setName(m_nameMapper.getDerivedFieldName(p.getName()));
    }
}

/** Adds the PPMatrix with one PPCell per entry of the model content. */
private void exportPPMatrix(final GeneralRegressionModel reg) {
    PPMatrix ppMatrix = reg.addNewPPMatrix();
    for (PMMLPPCell p : m_content.getPPMatrix()) {
        PPCell cell = ppMatrix.addNewPPCell();
        cell.setValue(p.getValue());
        cell.setPredictorName(m_nameMapper.getDerivedFieldName(p.getPredictorName()));
        cell.setParameterName(p.getParameterName());
        String targetCategory = p.getTargetCategory();
        if (targetCategory != null && !targetCategory.isEmpty()) {
            cell.setTargetCategory(targetCategory);
        }
    }
}

/** Adds the PCovMatrix, but only if the model content holds at least one covariance cell. */
private void exportPCovMatrix(final GeneralRegressionModel reg) {
    if (m_content.getPCovMatrix().length == 0) {
        return;
    }
    PCovMatrix pCovMatrix = reg.addNewPCovMatrix();
    for (PMMLPCovCell p : m_content.getPCovMatrix()) {
        PCovCell covCell = pCovMatrix.addNewPCovCell();
        covCell.setPRow(p.getPRow());
        covCell.setPCol(p.getPCol());
        String tCol = p.getTCol();
        String tRow = p.getTRow();
        // NOTE(review): when exactly one of tRow/tCol is null, the null one is still passed
        // to its setter - confirm this is intended.
        if (tRow != null || tCol != null) {
            covCell.setTRow(tRow);
            covCell.setTCol(tCol);
        }
        covCell.setValue(p.getValue());
        String targetCategory = p.getTargetCategory();
        if (targetCategory != null && !targetCategory.isEmpty()) {
            covCell.setTargetCategory(targetCategory);
        }
    }
}

/** Adds the ParamMatrix with one PCell per entry of the model content. */
private void exportParamMatrix(final GeneralRegressionModel reg) {
    ParamMatrix paramMatrix = reg.addNewParamMatrix();
    for (PMMLPCell p : m_content.getParamMatrix()) {
        PCell pCell = paramMatrix.addNewPCell();
        String targetCategory = p.getTargetCategory();
        if (targetCategory != null) {
            pCell.setTargetCategory(targetCategory);
        }
        pCell.setParameterName(p.getParameterName());
        pCell.setBeta(p.getBeta());
        Integer df = p.getDf();
        if (df != null) {
            pCell.setDf(BigInteger.valueOf(df));
        }
    }
}
Aggregations