Search in sources:

Example 36 with DerivedField

use of org.dmg.pmml.DerivedFieldDocument.DerivedField in project knime-core by knime.

The class PMMLNeuralNetworkTranslator, method addOutputLayer.

/**
 * Writes the PMML output layer of the MLP.
 * <p>
 * One {@code NeuralOutput} is added per perceptron of the last layer. Its output neuron id is
 * {@code "<lastLayerIndex>,<perceptronIndex>"}. Each output gets a derived field whose expression depends on the
 * MLP mode:
 * <ul>
 * <li>classification: categorical/string field with a {@code NormDiscrete} on the target column, bound to the
 * class value that the class mapping assigns to the perceptron index ({@code ""} if no class maps to it);</li>
 * <li>regression: continuous/double field with a {@code FieldRef} to the target column;</li>
 * <li>any other mode: categorical/string field without an expression.</li>
 * </ul>
 *
 * @param nnModel
 *            the neural network model.
 * @param mlp
 *            the underlying {@link MultiLayerPerceptron}.
 * @param spec
 *            the port object spec; its first target field is used as the target column
 */
protected void addOutputLayer(final NeuralNetwork nnModel, final MultiLayerPerceptron mlp, final PMMLPortObjectSpec spec) {
    final int lastlayer = mlp.getNrLayers() - 1;
    final String targetCol = spec.getTargetFields().iterator().next();
    final Perceptron[] outputperceptrons = mlp.getLayer(lastlayer).getPerceptrons();
    final boolean classification = mlp.getMode() == MultiLayerPerceptron.CLASSIFICATION_MODE;
    final boolean regression = mlp.getMode() == MultiLayerPerceptron.REGRESSION_MODE;

    // Invert the class mapping once (perceptron index -> class value) instead of scanning the whole map for
    // every output neuron. It is only needed in classification mode. Later entries overwrite earlier ones,
    // which keeps the "last match wins" result of a linear scan.
    final HashMap<Integer, DataCell> indexToClass = new HashMap<Integer, DataCell>();
    if (classification) {
        final HashMap<DataCell, Integer> outputmap = mlp.getClassMapping();
        if (outputmap != null) {
            for (Entry<DataCell, Integer> e : outputmap.entrySet()) {
                indexToClass.put(e.getValue(), e.getKey());
            }
        }
    }

    final NeuralOutputs neuralOuts = nnModel.addNewNeuralOutputs();
    neuralOuts.setNumberOfOutputs(BigInteger.valueOf(outputperceptrons.length));
    for (int i = 0; i < outputperceptrons.length; i++) {
        final NeuralOutput neuralOutput = neuralOuts.addNewNeuralOutput();
        neuralOutput.setOutputNeuron(lastlayer + "," + i);
        final DerivedField df = neuralOutput.addNewDerivedField();
        if (regression) {
            df.setOptype(OPTYPE.CONTINUOUS);
            df.setDataType(DATATYPE.DOUBLE);
            final FieldRef fieldRef = df.addNewFieldRef();
            fieldRef.setField(targetCol);
        } else {
            // Classification and any unrecognized mode are written as categorical string fields.
            df.setOptype(OPTYPE.CATEGORICAL);
            df.setDataType(DATATYPE.STRING);
            if (classification) {
                final DataCell classValue = indexToClass.get(i);
                final String colname = classValue == null ? "" : ((StringValue)classValue).getStringValue();
                final NormDiscrete normDiscrete = df.addNewNormDiscrete();
                normDiscrete.setField(targetCol);
                normDiscrete.setValue(colname);
            }
        }
    }
}
Also used : NeuralOutputs(org.dmg.pmml.NeuralOutputsDocument.NeuralOutputs) FieldRef(org.dmg.pmml.FieldRefDocument.FieldRef) NeuralLayer(org.dmg.pmml.NeuralLayerDocument.NeuralLayer) Layer(org.knime.base.data.neural.Layer) InputLayer(org.knime.base.data.neural.InputLayer) HiddenLayer(org.knime.base.data.neural.HiddenLayer) NeuralOutput(org.dmg.pmml.NeuralOutputDocument.NeuralOutput) BigInteger(java.math.BigInteger) NormDiscrete(org.dmg.pmml.NormDiscreteDocument.NormDiscrete) SigmoidPerceptron(org.knime.base.data.neural.SigmoidPerceptron) MultiLayerPerceptron(org.knime.base.data.neural.MultiLayerPerceptron) Perceptron(org.knime.base.data.neural.Perceptron) InputPerceptron(org.knime.base.data.neural.InputPerceptron) DataCell(org.knime.core.data.DataCell) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField)

Example 37 with DerivedField

use of org.dmg.pmml.DerivedFieldDocument.DerivedField in project knime-core by knime.

The class PMMLNeuralNetworkTranslator, method initiateNeuralOutputs.

/**
 * Rebuilds {@code m_classmap} from the {@code NeuralOutputs} of the PMML neural network.
 * <p>
 * For every {@code NeuralOutput}, {@code m_curPercpetronID} is set to its output neuron id. The neuron's position
 * is then taken from {@code m_idPosMap}:
 * <ul>
 * <li>a {@code NormDiscrete} expression maps its value to that position;</li>
 * <li>a {@code FieldRef} expression maps the referenced field name to that position.</li>
 * </ul>
 * Other expressions, and output neurons whose id is not contained in {@code m_idPosMap}, are reported via
 * {@code LOGGER.error} and skipped.
 *
 * @param nnModel
 *            the PMML neural network model
 */
private void initiateNeuralOutputs(final NeuralNetwork nnModel) {
    NeuralOutputs neuralOutputs = nnModel.getNeuralOutputs();
    m_classmap = new HashMap<DataCell, Integer>();
    for (NeuralOutput no : neuralOutputs.getNeuralOutputArray()) {
        m_curPercpetronID = no.getOutputNeuron();
        DerivedField df = no.getDerivedField();
        final String value;
        if (df.isSetNormDiscrete()) {
            value = df.getNormDiscrete().getValue();
        } else if (df.isSetFieldRef()) {
            value = df.getFieldRef().getField();
        } else {
            LOGGER.error("The expression is not supported in KNIME MLP.");
            continue;
        }
        // Keep the position boxed: unboxing a missing entry into an int would throw a NullPointerException
        // without any hint which output neuron is inconsistent.
        final Integer pos = m_idPosMap.get(m_curPercpetronID);
        if (pos == null) {
            LOGGER.error("Output neuron \"" + m_curPercpetronID + "\" is not defined in the neural network.");
            continue;
        }
        m_classmap.put(new StringCell(value), pos);
    }
}
Also used : NeuralOutputs(org.dmg.pmml.NeuralOutputsDocument.NeuralOutputs) BigInteger(java.math.BigInteger) StringCell(org.knime.core.data.def.StringCell) DataCell(org.knime.core.data.DataCell) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField) NeuralOutput(org.dmg.pmml.NeuralOutputDocument.NeuralOutput)

Example 38 with DerivedField

use of org.dmg.pmml.DerivedFieldDocument.DerivedField in project knime-core by knime.

The class DerivedFieldMapper, method processFields.

/**
 * Registers the name mappings of all derived fields that carry a display name.
 * <p>
 * For each such field, {@code m_derivedNames} maps display name to field name and {@code m_colNames} maps field
 * name to display name. Fields without a display name are ignored. A {@code null} array is a no-op.
 *
 * @param derivedFields the derived fields to register, may be {@code null}
 */
private void processFields(final DerivedField[] derivedFields) {
    if (derivedFields != null) {
        for (final DerivedField field : derivedFields) {
            final String displayName = field.getDisplayName();
            if (displayName != null) {
                // Several operations on one column yield several fields sharing the same display name
                // (column name). They are inserted in order, so letting later entries overwrite earlier ones
                // keeps the last, relevant mapping.
                final String fieldName = field.getName();
                m_derivedNames.put(displayName, fieldName);
                m_colNames.put(fieldName, displayName);
            }
        }
    }
}
Also used : DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField)

Example 39 with DerivedField

use of org.dmg.pmml.DerivedFieldDocument.DerivedField in project knime-core by knime.

The class DerivedFieldMapper, method getDerivedFields.

/**
 * Collects the derived fields of a PMML document.
 * <p>
 * The result first contains all derived fields of the {@code TransformationDictionary}, if present. These are
 * followed by the derived fields of the {@code LocalTransformations} of exactly one model. That model is the
 * first one of the first model type present, checked in this order:
 * <ol>
 * <li>association</li>
 * <li>clustering</li>
 * <li>general regression</li>
 * <li>naive Bayes</li>
 * <li>neural network</li>
 * <li>regression</li>
 * <li>rule set</li>
 * <li>sequence</li>
 * <li>support vector machine</li>
 * <li>text</li>
 * <li>time series</li>
 * <li>tree</li>
 * </ol>
 * Local transformations of further models, or of model types not listed here, are not included.
 *
 * @param pmml the pmml document to retrieve the derived fields from
 * @return the derived fields described above, or an empty array if no derived fields are defined.
 */
public static DerivedField[] getDerivedFields(final PMML pmml) {
    List<DerivedField> derivedFields = new ArrayList<DerivedField>();
    TransformationDictionary trans = pmml.getTransformationDictionary();
    if (trans != null) {
        derivedFields.addAll(Arrays.asList(trans.getDerivedFieldArray()));
    }
    LocalTransformations localTrans = null;
    if (pmml.getAssociationModelArray().length > 0) {
        localTrans = pmml.getAssociationModelArray(0).getLocalTransformations();
    } else if (pmml.getClusteringModelArray().length > 0) {
        localTrans = pmml.getClusteringModelArray(0).getLocalTransformations();
    } else if (pmml.getGeneralRegressionModelArray().length > 0) {
        localTrans = pmml.getGeneralRegressionModelArray(0).getLocalTransformations();
    } else if (pmml.getNaiveBayesModelArray().length > 0) {
        localTrans = pmml.getNaiveBayesModelArray(0).getLocalTransformations();
    } else if (pmml.getNeuralNetworkArray().length > 0) {
        localTrans = pmml.getNeuralNetworkArray(0).getLocalTransformations();
    } else if (pmml.getRegressionModelArray().length > 0) {
        localTrans = pmml.getRegressionModelArray(0).getLocalTransformations();
    } else if (pmml.getRuleSetModelArray().length > 0) {
        localTrans = pmml.getRuleSetModelArray(0).getLocalTransformations();
    } else if (pmml.getSequenceModelArray().length > 0) {
        localTrans = pmml.getSequenceModelArray(0).getLocalTransformations();
    } else if (pmml.getSupportVectorMachineModelArray().length > 0) {
        localTrans = pmml.getSupportVectorMachineModelArray(0).getLocalTransformations();
    } else if (pmml.getTextModelArray().length > 0) {
        localTrans = pmml.getTextModelArray(0).getLocalTransformations();
    } else if (pmml.getTimeSeriesModelArray().length > 0) {
        localTrans = pmml.getTimeSeriesModelArray(0).getLocalTransformations();
    } else if (pmml.getTreeModelArray().length > 0) {
        localTrans = pmml.getTreeModelArray(0).getLocalTransformations();
    }
    // The former trailing "sizeOfRuleSetModelArray() > 0" branch was unreachable (rule set models are handled
    // above) and has been removed.
    if (localTrans != null) {
        derivedFields.addAll(Arrays.asList(localTrans.getDerivedFieldArray()));
    }
    return derivedFields.toArray(new DerivedField[0]);
}
Also used : LocalTransformations(org.dmg.pmml.LocalTransformationsDocument.LocalTransformations) TransformationDictionary(org.dmg.pmml.TransformationDictionaryDocument.TransformationDictionary) ArrayList(java.util.ArrayList) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField)

Example 40 with DerivedField

use of org.dmg.pmml.DerivedFieldDocument.DerivedField in project knime-core by knime.

The class PMMLBinningTranslator, method initializeFrom.

/**
 * Reads all {@code Discretize} derived fields into KNIME numeric bins.
 * <p>
 * Derived fields without a {@code Discretize} expression are skipped. For each consumed field the bins are
 * stored in {@code m_columnToBins}, keyed by a KNIME column name chosen as follows:
 * <ul>
 * <li>If the field has a display name, that name is the key and {@code m_columnToAppend} maps it to
 * {@code null}.</li>
 * <li>Otherwise, if it has a name, the discretized source field is resolved to a KNIME column through the
 * {@link DerivedFieldMapper}. That column is the key, and {@code m_columnToAppend} maps it to the derived field
 * name.</li>
 * </ul>
 * An interval margin that is absent in PMML is read as an unbounded margin (negative or positive infinity).
 *
 * @param derivedFields the derived fields to read
 * @return the indices (into {@code derivedFields}) of all {@code Discretize} fields
 */
@Override
public List<Integer> initializeFrom(final DerivedField[] derivedFields) {
    m_mapper = new DerivedFieldMapper(derivedFields);
    final List<Integer> consumed = new ArrayList<>(derivedFields.length);
    for (int i = 0; i < derivedFields.length; i++) {
        final DerivedField df = derivedFields[i];
        if (!df.isSetDiscretize()) {
            // only reading discretize entries other entries are skipped
            continue;
        }
        consumed.add(i);
        final Discretize discretize = df.getDiscretize();
        @SuppressWarnings("deprecation") final DiscretizeBin[] pmmlBins = discretize.getDiscretizeBinArray();
        final NumericBin[] knimeBins = new NumericBin[pmmlBins.length];
        for (int j = 0; j < pmmlBins.length; j++) {
            final DiscretizeBin bin = pmmlBins[j];
            final String binName = bin.getBinValue();
            final Interval interval = bin.getInterval();
            // In PMML an absent margin means the interval is unbounded on that side. The XMLBeans getter
            // returns 0.0 for an absent attribute. That would silently turn an open-ended bin into one
            // bounded at zero, so check isSet first.
            final double leftValue =
                interval.isSetLeftMargin() ? interval.getLeftMargin() : Double.NEGATIVE_INFINITY;
            final double rightValue =
                interval.isSetRightMargin() ? interval.getRightMargin() : Double.POSITIVE_INFINITY;
            final Closure.Enum closure = interval.getClosure();
            boolean leftOpen = true;
            boolean rightOpen = true;
            if (Closure.OPEN_CLOSED == closure) {
                rightOpen = false;
            } else if (Closure.CLOSED_OPEN == closure) {
                leftOpen = false;
            } else if (Closure.CLOSED_CLOSED == closure) {
                leftOpen = false;
                rightOpen = false;
            }
            knimeBins[j] = new NumericBin(binName, leftOpen, leftValue, rightOpen, rightValue);
        }
        /*
         * The display name holds the name of the KNIME column that corresponds to the derived field in PMML.
         * This is necessary if derived fields are defined on other derived fields and the columns in KNIME are
         * replaced with the preprocessed values. In that case KNIME has to know the original names (e.g. A),
         * while PMML references A', A'' etc.
         */
        final String displayName = df.getDisplayName();
        if (displayName != null) {
            m_columnToBins.put(displayName, knimeBins);
            m_columnToAppend.put(displayName, null);
        } else if (df.getName() != null) {
            final String field = m_mapper.getColumnName(discretize.getField());
            m_columnToBins.put(field, knimeBins);
            m_columnToAppend.put(field, df.getName());
        }
    }
    return consumed;
}
Also used : Closure(org.dmg.pmml.IntervalDocument.Interval.Closure) ArrayList(java.util.ArrayList) DerivedFieldMapper(org.knime.core.node.port.pmml.preproc.DerivedFieldMapper) Discretize(org.dmg.pmml.DiscretizeDocument.Discretize) DiscretizeBin(org.dmg.pmml.DiscretizeBinDocument.DiscretizeBin) DerivedField(org.dmg.pmml.DerivedFieldDocument.DerivedField) Interval(org.dmg.pmml.IntervalDocument.Interval)

Aggregations

DerivedField (org.dmg.pmml.DerivedFieldDocument.DerivedField)41 ArrayList (java.util.ArrayList)12 FieldRef (org.dmg.pmml.FieldRefDocument.FieldRef)11 BigInteger (java.math.BigInteger)9 DerivedFieldMapper (org.knime.core.node.port.pmml.preproc.DerivedFieldMapper)8 MapValues (org.dmg.pmml.MapValuesDocument.MapValues)7 DataColumnSpec (org.knime.core.data.DataColumnSpec)6 Apply (org.dmg.pmml.ApplyDocument.Apply)5 DiscretizeBin (org.dmg.pmml.DiscretizeBinDocument.DiscretizeBin)5 Discretize (org.dmg.pmml.DiscretizeDocument.Discretize)5 Interval (org.dmg.pmml.IntervalDocument.Interval)5 NormDiscrete (org.dmg.pmml.NormDiscreteDocument.NormDiscrete)5 DataCell (org.knime.core.data.DataCell)5 DataType (org.knime.core.data.DataType)5 LinkedHashMap (java.util.LinkedHashMap)4 LinkedHashSet (java.util.LinkedHashSet)4 Map (java.util.Map)4 LocalTransformations (org.dmg.pmml.LocalTransformationsDocument.LocalTransformations)4 NeuralLayer (org.dmg.pmml.NeuralLayerDocument.NeuralLayer)4 NeuralOutput (org.dmg.pmml.NeuralOutputDocument.NeuralOutput)4