Usage of org.dmg.pmml.NormContinuousDocument.NormContinuous in the knime-core project, taken from the class PMMLNormalizeTranslator, method createDerivedFields.
/**
 * Creates one PMML derived field per normalized column. Each derived field
 * encodes the affine transformation {@code y = scale * x + translation} as
 * a NormContinuous element with two LinearNorm support points
 * (orig=0 -> translation, orig=1 -> scale + translation).
 *
 * @return a derived field for every column of the affine transformation
 */
private DerivedField[] createDerivedFields() {
    // Hoist the parallel arrays once instead of re-querying the
    // transformation object on every loop iteration.
    String[] names = m_affineTrans.getNames();
    double[] translations = m_affineTrans.getTranslations();
    double[] scales = m_affineTrans.getScales();
    DerivedField[] derivedFields = new DerivedField[names.length];
    for (int i = 0; i < names.length; i++) {
        DerivedField df = DerivedField.Factory.newInstance();
        df.setExtensionArray(createSummaryExtension());
        String name = names[i];
        df.setDisplayName(name);
        /* The field name must be retrieved before creating a new derived
         * name for this derived field as the map only contains the
         * current mapping. */
        String fieldName = m_mapper.getDerivedFieldName(name);
        df.setName(m_mapper.createDerivedFieldName(name));
        df.setOptype(OPTYPE.CONTINUOUS);
        df.setDataType(DATATYPE.DOUBLE);
        NormContinuous cont = df.addNewNormContinuous();
        cont.setField(fieldName);
        double trans = translations[i];
        double scale = scales[i];
        // Two support points uniquely determine the linear mapping.
        LinearNorm firstNorm = cont.addNewLinearNorm();
        firstNorm.setOrig(0.0);
        firstNorm.setNorm(trans);
        LinearNorm secondNorm = cont.addNewLinearNorm();
        secondNorm.setOrig(1.0);
        secondNorm.setNorm(scale + trans);
        derivedFields[i] = df;
    }
    return derivedFields;
}
Usage of org.dmg.pmml.NormContinuousDocument.NormContinuous in the knime-core project, taken from the class DataColumnSpecFilterPMMLNodeModel, method createPMMLOut.
/**
 * Creates the outgoing PMML port object with all excluded columns removed
 * from the data dictionary. Models and the transformation dictionary are
 * additionally inspected for references to excluded columns; any such
 * reference is reported via the node's warning message (the referencing
 * elements themselves are left untouched).
 *
 * @param pmmlIn the incoming PMML port object, may be {@code null}
 * @param outSpec the filtered output table spec
 * @param res the filter result describing included/excluded columns
 * @return the filtered PMML port object
 * @throws XmlException if the incoming PMML document cannot be parsed
 */
private PMMLPortObject createPMMLOut(final PMMLPortObject pmmlIn, final DataTableSpec outSpec, final FilterResult res) throws XmlException {
    if (pmmlIn == null) {
        return new PMMLPortObject(createPMMLSpec(null, outSpec, res));
    }
    PMMLDocument pmmldoc;
    try (LockedSupplier<Document> supplier = pmmlIn.getPMMLValue().getDocumentSupplier()) {
        pmmldoc = PMMLDocument.Factory.parse(supplier.get());
    }
    // StringBuilder suffices: the warning text is never shared between threads.
    StringBuilder warningBuffer = new StringBuilder();
    appendModelWarnings(pmmldoc, res, warningBuffer);
    PMML pmml = pmmldoc.getPMML();
    List<String> warningFields = collectExcludedTransformationFields(pmml, res);
    filterDataDictionary(pmml, res);
    pmmldoc.setPMML(pmml);
    // generate warnings and set as warning message
    for (String s : warningFields) {
        if (warningBuffer.length() != 0) {
            warningBuffer.append("\n");
        }
        warningBuffer.append("Transformation dictionary uses excluded column ").append(s);
    }
    if (warningBuffer.length() > 0) {
        setWarningMessage(warningBuffer.toString().trim());
    }
    try {
        return new PMMLPortObject(createPMMLSpec(pmmlIn.getSpec(), outSpec, res), pmmldoc);
    } catch (IllegalArgumentException e) {
        if (res.getIncludes().length == 0) {
            // Give a clearer diagnosis for the common "everything excluded" case.
            throw new IllegalArgumentException("Excluding all columns produces invalid PMML", e);
        }
        throw e;
    }
}

/**
 * Appends one warning line per mining field of any model that references an
 * excluded column.
 *
 * @param pmmldoc the parsed PMML document
 * @param res the filter result used to decide exclusion
 * @param warningBuffer the buffer the warning lines are appended to
 */
private void appendModelWarnings(final PMMLDocument pmmldoc, final FilterResult res, final StringBuilder warningBuffer) {
    for (PMMLModelWrapper model : PMMLModelWrapper.getModelListFromPMMLDocument(pmmldoc)) {
        MiningSchema ms = model.getMiningSchema();
        for (MiningField mf : ms.getMiningFieldList()) {
            if (isExcluded(mf.getName(), res)) {
                if (warningBuffer.length() != 0) {
                    warningBuffer.append("\n");
                }
                warningBuffer.append(model.getModelType().name()).append(" uses excluded column ").append(mf.getName());
            }
        }
    }
}

/**
 * Collects the names of excluded columns that are referenced by derived
 * fields in the transformation dictionary (field refs, aggregates, applies,
 * discretizations, map values, norm continuous/discrete).
 *
 * @param pmml the PMML fragment to inspect
 * @param res the filter result used to decide exclusion
 * @return the referenced excluded column names, empty if none
 */
private List<String> collectExcludedTransformationFields(final PMML pmml, final FilterResult res) {
    List<String> warningFields = new ArrayList<>();
    if (pmml.getTransformationDictionary() == null) {
        return warningFields;
    }
    for (DerivedField df : pmml.getTransformationDictionary().getDerivedFieldList()) {
        FieldRef fr = df.getFieldRef();
        if (fr != null && isExcluded(fr.getField(), res)) {
            warningFields.add(fr.getField());
        }
        Aggregate a = df.getAggregate();
        if (a != null && isExcluded(a.getField(), res)) {
            warningFields.add(a.getField());
        }
        Apply ap = df.getApply();
        if (ap != null) {
            for (FieldRef fieldRef : ap.getFieldRefList()) {
                if (isExcluded(fieldRef.getField(), res)) {
                    // One warning per apply element is enough.
                    warningFields.add(fieldRef.getField());
                    break;
                }
            }
        }
        Discretize d = df.getDiscretize();
        if (d != null && isExcluded(d.getField(), res)) {
            warningFields.add(d.getField());
        }
        MapValues mv = df.getMapValues();
        if (mv != null) {
            for (FieldColumnPair fcp : mv.getFieldColumnPairList()) {
                if (isExcluded(fcp.getField(), res)) {
                    warningFields.add(fcp.getField());
                }
            }
        }
        NormContinuous nc = df.getNormContinuous();
        if (nc != null && isExcluded(nc.getField(), res)) {
            warningFields.add(nc.getField());
        }
        NormDiscrete nd = df.getNormDiscrete();
        if (nd != null && isExcluded(nd.getField(), res)) {
            warningFields.add(nd.getField());
        }
    }
    return warningFields;
}

/**
 * Removes excluded columns from the data dictionary and updates its field
 * count accordingly.
 *
 * @param pmml the PMML fragment whose data dictionary is filtered in place
 * @param res the filter result used to decide exclusion
 */
private void filterDataDictionary(final PMML pmml, final FilterResult res) {
    DataDictionary dict = pmml.getDataDictionary();
    List<DataField> fields = dict.getDataFieldList();
    int numFields = 0;
    // Iterate backwards so removal does not shift the indices still pending.
    for (int i = fields.size() - 1; i >= 0; i--) {
        if (isExcluded(fields.get(i).getName(), res)) {
            dict.removeDataField(i);
        } else {
            numFields++;
        }
    }
    // valueOf avoids the String round-trip of new BigInteger(Integer.toString(...)).
    dict.setNumberOfFields(BigInteger.valueOf(numFields));
    pmml.setDataDictionary(dict);
}
Usage of org.dmg.pmml.NormContinuousDocument.NormContinuous in the knime-core project, taken from the class PMMLNormalizeTranslator, method initializeFrom.
/**
 * {@inheritDoc}
 */
@Override
public List<Integer> initializeFrom(final DerivedField[] derivedFields) {
    if (derivedFields == null) {
        // Typed factory method instead of the raw Collections.EMPTY_LIST,
        // which causes an unchecked-assignment warning at the caller.
        return Collections.emptyList();
    }
    m_mapper = new DerivedFieldMapper(derivedFields);
    int num = derivedFields.length;
    List<Integer> consumed = new ArrayList<>(num);
    if (num > 0) {
        parseExtensionArray(derivedFields[0].getExtensionArray());
    }
    for (int i = 0; i < derivedFields.length; i++) {
        DerivedField df = derivedFields[i];
        /* This field contains the name of the column in KNIME that
         * corresponds to the derived field in PMML. This is necessary if
         * derived fields are defined on other derived fields and the
         * columns in KNIME are replaced with the preprocessed values.
         * In this case KNIME has to know the original names (e.g. A) while
         * PMML references to A', A'' etc. */
        String displayName = df.getDisplayName();
        if (!df.isSetNormContinuous()) {
            // Only NormContinuous entries are read; other entries are skipped.
            continue;
        }
        consumed.add(i);
        NormContinuous normContinuous = df.getNormContinuous();
        LinearNorm[] norms = normContinuous.getLinearNormArray();
        // Exactly two support points (orig, norm) define the affine map.
        // Fewer than two would previously have produced a NaN or infinite
        // scale in the division below; reject both malformed cases.
        if (norms.length > 2) {
            throw new IllegalArgumentException("Only two LinearNorm elements are supported per NormContinuous");
        }
        if (norms.length < 2) {
            throw new IllegalArgumentException("At least two LinearNorm elements are required per NormContinuous");
        }
        double[] orig = new double[MAX_NUM_SEGMENTS];
        double[] norm = new double[MAX_NUM_SEGMENTS];
        for (int j = 0; j < norms.length; j++) {
            orig[j] = norms[j].getOrig();
            norm[j] = norms[j].getNorm();
        }
        // Recover scale and translation of y = scale * x + translation
        // from the two support points.
        double scale = (norm[1] - norm[0]) / (orig[1] - orig[0]);
        m_scales.add(scale);
        m_translations.add(norm[0] - scale * orig[0]);
        if (displayName != null) {
            m_fields.add(displayName);
        } else {
            m_fields.add(m_mapper.getColumnName(normContinuous.getField()));
        }
    }
    return consumed;
}
Aggregations