use of org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork in project knime-core by knime.
the class PMMLNeuralNetworkTranslator method initializeFrom.
/**
 * {@inheritDoc}
 * <p>
 * Builds the multi layer perceptron from the first neural network model of
 * the given PMML document; further neural network models are ignored with a
 * warning. An activation function other than logistic or a normalization
 * method other than none is only reported via the logger, processing
 * continues. If the mining function is neither classification nor
 * regression, the current MLP mode is left unchanged.
 *
 * @throws IllegalArgumentException if the document contains no neural
 *             network model or if fewer than three layers were collected
 */
@Override
public void initializeFrom(final PMMLDocument pmmlDoc) {
    m_nameMapper = new DerivedFieldMapper(pmmlDoc);

    final NeuralNetwork[] networks = pmmlDoc.getPMML().getNeuralNetworkArray();
    if (networks.length == 0) {
        throw new IllegalArgumentException("No neural network model" + " provided.");
    }
    if (networks.length > 1) {
        LOGGER.warn("Multiple neural network models found. " + "Only the first model is considered.");
    }
    final NeuralNetwork network = networks[0];

    // Populate the layers and mappings; the order of these calls is kept
    // as is because each of them works on shared member state.
    initInputLayer(network);
    initiateHiddenLayers(network);
    initiateFinalLayer(network);
    initiateNeuralOutputs(network);

    // Report (but tolerate) network properties that are not supported.
    final ACTIVATIONFUNCTION.Enum activation = network.getActivationFunction();
    final NNNORMALIZATIONMETHOD.Enum normalization = network.getNormalizationMethod();
    if (activation != ACTIVATIONFUNCTION.LOGISTIC) {
        LOGGER.error("Only logistic activation function is " + "supported in KNIME MLP.");
    }
    if (normalization != NNNORMALIZATIONMETHOD.NONE) {
        LOGGER.error("No normalization method is " + "supported in KNIME MLP.");
    }

    // Translate the mining function into the MLP mode.
    final MININGFUNCTION.Enum miningFunction = network.getFunctionName();
    if (miningFunction == MININGFUNCTION.CLASSIFICATION) {
        m_mlpMethod = MultiLayerPerceptron.CLASSIFICATION_MODE;
    } else if (miningFunction == MININGFUNCTION.REGRESSION) {
        m_mlpMethod = MultiLayerPerceptron.REGRESSION_MODE;
    }

    final int layerCount = m_allLayers.size();
    if (layerCount < 3) {
        throw new IllegalArgumentException("Only neural networks with 3 Layers supported in KNIME MLP.");
    }
    final Layer[] layers = m_allLayers.toArray(new Layer[layerCount]);
    m_mlp = new MultiLayerPerceptron(layers);

    // Architecture arguments: size of the first layer, number of layers
    // between first and last, size of the second layer, size of the last
    // layer.
    final int firstLayerSize = layers[0].getPerceptrons().length;
    final int innerLayerCount = layers.length - 2;
    final int secondLayerSize = layers[1].getPerceptrons().length;
    final int lastLayerSize = layers[layers.length - 1].getPerceptrons().length;
    final Architecture architecture =
        new Architecture(firstLayerSize, innerLayerCount, secondLayerSize, lastLayerSize);

    m_mlp.setArchitecture(architecture);
    m_mlp.setClassMapping(m_classmap);
    m_mlp.setInputMapping(m_inputmap);
    m_mlp.setMode(m_mlpMethod);
}
use of org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork in project knime-core by knime.
the class PMMLUtils method getFirstMiningSchema.
/**
 * Looks up the mining schema of the first model of the requested type.
 *
 * @param pmmlDoc the PMML document that is searched for the model
 * @param type the schema type identifying the kind of model
 * @return the mining schema of the first model of the given type, or
 *         {@code null} if the document holds no model of that type or the
 *         type is not one of the model types handled here
 */
public static MiningSchema getFirstMiningSchema(final PMMLDocument pmmlDoc, final SchemaType type) {
    if (!getNumberOfModels(pmmlDoc).containsKey(PMMLModelType.getType(type))) {
        return null;
    }
    final PMML pmml = pmmlDoc.getPMML();

    /*
     * The PMML model classes share no common base class, so every model
     * type needs its own branch to reach its mining schema.
     */
    if (AssociationModel.type.equals(type)) {
        return pmml.getAssociationModelArray(0).getMiningSchema();
    }
    if (ClusteringModel.type.equals(type)) {
        return pmml.getClusteringModelArray(0).getMiningSchema();
    }
    if (GeneralRegressionModel.type.equals(type)) {
        return pmml.getGeneralRegressionModelArray(0).getMiningSchema();
    }
    if (MiningModel.type.equals(type)) {
        return pmml.getMiningModelArray(0).getMiningSchema();
    }
    if (NaiveBayesModel.type.equals(type)) {
        return pmml.getNaiveBayesModelArray(0).getMiningSchema();
    }
    if (NeuralNetwork.type.equals(type)) {
        return pmml.getNeuralNetworkArray(0).getMiningSchema();
    }
    if (RegressionModel.type.equals(type)) {
        return pmml.getRegressionModelArray(0).getMiningSchema();
    }
    if (RuleSetModel.type.equals(type)) {
        return pmml.getRuleSetModelArray(0).getMiningSchema();
    }
    if (SequenceModel.type.equals(type)) {
        return pmml.getSequenceModelArray(0).getMiningSchema();
    }
    if (SupportVectorMachineModel.type.equals(type)) {
        return pmml.getSupportVectorMachineModelArray(0).getMiningSchema();
    }
    if (TextModel.type.equals(type)) {
        return pmml.getTextModelArray(0).getMiningSchema();
    }
    if (TimeSeriesModel.type.equals(type)) {
        return pmml.getTimeSeriesModelArray(0).getMiningSchema();
    }
    if (TreeModel.type.equals(type)) {
        return pmml.getTreeModelArray(0).getMiningSchema();
    }
    // None of the handled model types matched.
    return null;
}
use of org.dmg.pmml.NeuralNetworkDocument.NeuralNetwork in project knime-core by knime.
the class PMMLPortObject method moveDerivedFields.
/**
 * Transfers the derived fields and extensions of the transformation
 * dictionary into the local transformations of the first model of the
 * given type and afterwards empties the dictionary.
 * <p>
 * If the type is not one of the handled model types, the dictionary is
 * left untouched; the returned object has nevertheless been filled from
 * the dictionary. For a non-null unhandled type an error is logged.
 *
 * @param type the type of model to move the derived fields to
 * @return the {@link LocalTransformations} element filled from the
 *         transformation dictionary, or an empty local transformation
 *         object if the document has no transformation dictionary
 */
private LocalTransformations moveDerivedFields(final SchemaType type) {
    final PMML pmml = m_pmmlDoc.getPMML();
    final TransformationDictionary dictionary = pmml.getTransformationDictionary();
    final LocalTransformations moved = LocalTransformations.Factory.newInstance();
    if (dictionary == null) {
        // no dictionary, hence nothing to transfer
        return moved;
    }
    moved.setDerivedFieldArray(dictionary.getDerivedFieldArray());
    moved.setExtensionArray(dictionary.getExtensionArray());

    /*
     * The PMML model classes share no common base class, so every model
     * type needs its own branch to receive the local transformations.
     */
    if (AssociationModel.type.equals(type)) {
        pmml.getAssociationModelArray(0).setLocalTransformations(moved);
    } else if (ClusteringModel.type.equals(type)) {
        pmml.getClusteringModelArray(0).setLocalTransformations(moved);
    } else if (GeneralRegressionModel.type.equals(type)) {
        pmml.getGeneralRegressionModelArray(0).setLocalTransformations(moved);
    } else if (MiningModel.type.equals(type)) {
        pmml.getMiningModelArray(0).setLocalTransformations(moved);
    } else if (NaiveBayesModel.type.equals(type)) {
        pmml.getNaiveBayesModelArray(0).setLocalTransformations(moved);
    } else if (NeuralNetwork.type.equals(type)) {
        pmml.getNeuralNetworkArray(0).setLocalTransformations(moved);
    } else if (RegressionModel.type.equals(type)) {
        pmml.getRegressionModelArray(0).setLocalTransformations(moved);
    } else if (RuleSetModel.type.equals(type)) {
        pmml.getRuleSetModelArray(0).setLocalTransformations(moved);
    } else if (SequenceModel.type.equals(type)) {
        pmml.getSequenceModelArray(0).setLocalTransformations(moved);
    } else if (SupportVectorMachineModel.type.equals(type)) {
        pmml.getSupportVectorMachineModelArray(0).setLocalTransformations(moved);
    } else if (TextModel.type.equals(type)) {
        pmml.getTextModelArray(0).setLocalTransformations(moved);
    } else if (TimeSeriesModel.type.equals(type)) {
        pmml.getTimeSeriesModelArray(0).setLocalTransformations(moved);
    } else if (TreeModel.type.equals(type)) {
        pmml.getTreeModelArray(0).setLocalTransformations(moved);
    } else {
        // Unhandled type: keep the dictionary as it is and bail out.
        if (type != null) {
            LOGGER.error("Could not move TransformationDictionary to " + "unsupported model of type \"" + type + "\".");
        }
        return moved;
    }

    // The content now lives in the model, so clear it from the dictionary.
    dictionary.setDerivedFieldArray(new DerivedField[0]);
    dictionary.setExtensionArray(new ExtensionDocument.Extension[0]);
    return moved;
}
Aggregations