Use of org.dmg.pmml.neural_network.NeuralLayer in project jpmml-r by jpmml.
The example below comes from the class NNetConverter, method encodeModel.
@Override
public Model encodeModel(Schema schema) {
    RGenericVector nnet = getObject();
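
    // Relevant components of the fitted R "nnet" object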
    RDoubleVector n = nnet.getDoubleElement("n");
    RBooleanVector linout = nnet.getBooleanElement("linout", false);
    RBooleanVector softmax = nnet.getBooleanElement("softmax", false);
    RBooleanVector censored = nnet.getBooleanElement("censored", false);
    RDoubleVector wts = nnet.getDoubleElement("wts");
    RStringVector lev = nnet.getStringElement("lev", false);
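
    // The "n" element holds the network topology: the numbers of input, hidden and output units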
    if (n.size() != 3) {
        throw new IllegalArgumentException();
    }

    Label label = schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();
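
    // Absent factor levels ("lev") indicate a regression model, which must use
    // linear output units ("linout" = TRUE); otherwise the model is a classifier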
    MiningFunction miningFunction;

    if (lev == null) {

        if (linout != null && !linout.asScalar()) {
            throw new IllegalArgumentException();
        }

        miningFunction = MiningFunction.REGRESSION;
    } else {
        miningFunction = MiningFunction.CLASSIFICATION;
    }
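
    // The number of input units must match the number of features in the schema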
    int nInput = ValueUtil.asInt(n.getValue(0));
    SchemaUtil.checkSize(nInput, features);

    NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.DOUBLE);

    int offset = 0;

    List<NeuralLayer> neuralLayers = new ArrayList<>();

    List<? extends NeuralEntity> entities = neuralInputs.getNeuralInputs();
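
    // Hidden layer with logistic activation. The "wts" vector stores the weights layer by
    // layer; each neuron consumes (number of incoming entities + 1) values, the extra one
    // being the bias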
    int nHidden = ValueUtil.asInt(n.getValue(1));
    if (nHidden > 0) {
        NeuralLayer neuralLayer = encodeNeuralLayer("hidden", nHidden, entities, wts, offset)
            .setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC);

        offset += (nHidden * (entities.size() + 1));

        neuralLayers.add(neuralLayer);

        entities = neuralLayer.getNeurons();
    }
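
    // Output layer; the encoding depends on the number of output units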
    int nOutput = ValueUtil.asInt(n.getValue(2));
    if (nOutput == 1) {
        NeuralLayer neuralLayer = encodeNeuralLayer("output", nOutput, entities, wts, offset);

        offset += (nOutput * (entities.size() + 1));

        neuralLayers.add(neuralLayer);

        entities = neuralLayer.getNeurons();

        switch (miningFunction) {
            case REGRESSION:
                break;
            case CLASSIFICATION:
                {
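                    // Two-class case: derive the complementary class probability
                    // from the single logistic output unit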
                    List<NeuralLayer> transformationNeuralLayers = NeuralNetworkUtil.createBinaryLogisticTransformation(Iterables.getOnlyElement(entities));

                    neuralLayers.addAll(transformationNeuralLayers);

                    neuralLayer = Iterables.getLast(transformationNeuralLayers);

                    entities = neuralLayer.getNeurons();
                }
                break;
        }
    } else if (nOutput > 1) {
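        // Several output units: apply softmax normalization if requested;
        // censored softmax is not supported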
NeuralLayer neuralLayer = encodeNeuralLayer("output", nOutput, entities, wts, offset);
if (softmax != null && softmax.asScalar()) {
if (censored != null && censored.asScalar()) {
throw new IllegalArgumentException();
}
neuralLayer.setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
}
offset += (nOutput * (entities.size() + 1));
neuralLayers.add(neuralLayer);
entities = neuralLayer.getNeurons();
} else {
throw new IllegalArgumentException();
}
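
    // Assemble the PMML neural network with identity as the default activation function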
    NeuralNetwork neuralNetwork = new NeuralNetwork(miningFunction, NeuralNetwork.ActivationFunction.IDENTITY, ModelUtil.createMiningSchema(label), neuralInputs, neuralLayers);
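
    // Connect the neurons of the last layer to the target field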
    switch (miningFunction) {
        case REGRESSION:
            neuralNetwork.setNeuralOutputs(NeuralNetworkUtil.createRegressionNeuralOutputs(entities, (ContinuousLabel) label));
            break;
        case CLASSIFICATION:
            neuralNetwork
                .setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(entities, (CategoricalLabel) label))
                .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, (CategoricalLabel) label));
            break;
    }

    return neuralNetwork;
}
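
The offset arithmetic above mirrors the way R's nnet stores its coefficients: "wts" is a flat vector in which every unit of a layer owns a contiguous run of (fan-in + 1) values, the first of them being the bias. Below is a minimal, self-contained sketch of that bookkeeping; the class name and the layer sizes are illustrative and not part of jpmml-r.

import java.util.List;

public class NnetWeightLayout {

    // Prints which slice of the flat "wts" vector feeds each unit of each layer,
    // using the same (fan-in + 1) stride as encodeNeuralLayer above.
    public static void main(String[] args) {
        // n = c(2, 3, 1): two inputs, three hidden units, one output unit (assumed example)
        List<Integer> n = List.of(2, 3, 1);

        int offset = 0;
        int fanIn = n.get(0);

        for (int layer = 1; layer < n.size(); layer++) {
            int units = n.get(layer);

            for (int unit = 0; unit < units; unit++) {
                int from = offset + unit * (fanIn + 1);
                int to = from + fanIn;

                System.out.println("layer " + layer + ", unit " + unit + ": wts[" + from + ".." + to + "] (bias first)");
            }

            offset += units * (fanIn + 1);
            fanIn = units;
        }
    }
}

For n = c(2, 3, 1) this accounts for 3 * (2 + 1) + 1 * (3 + 1) = 13 values, which is the number of coefficients that nnet reports for such a network without skip-layer connections.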