Search in sources:

Example 6 with NeuralLayer

Use of org.dmg.pmml.neural_network.NeuralLayer in project jpmml-r by jpmml.

Class NNetConverter, method encodeModel.

/**
 * Encodes an R {@code nnet} object as a PMML {@code NeuralNetwork} model.
 *
 * <p>The layer sizes are read from the {@code n} element, which must hold exactly three values:
 * the number of input, hidden and output units. Weights are read sequentially from the
 * {@code wts} element, layer by layer. For each neuron the read offset advances by one bias
 * weight plus one weight per incoming entity.
 *
 * <p>The model is a regression when the {@code lev} element is absent, and a classification
 * otherwise. The hidden layer (if any) uses a logistic activation. A single-output classifier is
 * followed by a binary logistic transformation. A multi-output network gets softmax
 * normalization when the {@code softmax} flag is set.
 *
 * @param schema the schema providing the label and the input features
 * @return the encoded neural network
 * @throws IllegalArgumentException if the R object describes a network this converter cannot
 *     represent: malformed layer sizes, a regression whose {@code linout} flag is false, softmax
 *     combined with censoring, no output units, or a weight vector whose length does not match
 *     the plain input-hidden-output layout (presumably a network fitted with skip-layer
 *     connections, whose extra weights this layout does not account for)
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector nnet = getObject();

    RDoubleVector n = nnet.getDoubleElement("n");
    RBooleanVector linout = nnet.getBooleanElement("linout", false);
    RBooleanVector softmax = nnet.getBooleanElement("softmax", false);
    RBooleanVector censored = nnet.getBooleanElement("censored", false);
    RDoubleVector wts = nnet.getDoubleElement("wts");
    RStringVector lev = nnet.getStringElement("lev", false);

    // n = (number of input units, number of hidden units, number of output units)
    if (n.size() != 3) {
        throw new IllegalArgumentException("Expected 3 layer sizes, got " + n.size());
    }

    Label label = schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();

    // Without class levels the network is a regression, which requires linear output units
    MiningFunction miningFunction;
    if (lev == null) {
        if (linout != null && !linout.asScalar()) {
            throw new IllegalArgumentException("Regression network must have linear output units (linout = TRUE)");
        }
        miningFunction = MiningFunction.REGRESSION;
    } else {
        miningFunction = MiningFunction.CLASSIFICATION;
    }

    int nInput = ValueUtil.asInt(n.getValue(0));
    SchemaUtil.checkSize(nInput, features);

    NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.DOUBLE);

    // Position of the next unread weight in wts
    int offset = 0;

    List<NeuralLayer> neuralLayers = new ArrayList<>();

    // Entities feeding the layer that is about to be encoded
    List<? extends NeuralEntity> entities = neuralInputs.getNeuralInputs();

    int nHidden = ValueUtil.asInt(n.getValue(1));
    if (nHidden > 0) {
        NeuralLayer neuralLayer = encodeNeuralLayer("hidden", nHidden, entities, wts, offset).setActivationFunction(NeuralNetwork.ActivationFunction.LOGISTIC);

        // One bias weight plus one weight per incoming entity, for every hidden neuron
        offset += (nHidden * (entities.size() + 1));

        neuralLayers.add(neuralLayer);
        entities = neuralLayer.getNeurons();
    }

    int nOutput = ValueUtil.asInt(n.getValue(2));
    if (nOutput == 1) {
        NeuralLayer neuralLayer = encodeNeuralLayer("output", nOutput, entities, wts, offset);

        offset += (nOutput * (entities.size() + 1));

        neuralLayers.add(neuralLayer);
        entities = neuralLayer.getNeurons();

        switch(miningFunction) {
            case REGRESSION:
                break;
            case CLASSIFICATION:
                {
                    // A single output unit encodes a binary classifier; append a logistic
                    // transformation so the PMML exposes one neuron per class
                    List<NeuralLayer> transformationNeuralLayers = NeuralNetworkUtil.createBinaryLogisticTransformation(Iterables.getOnlyElement(entities));
                    neuralLayers.addAll(transformationNeuralLayers);

                    neuralLayer = Iterables.getLast(transformationNeuralLayers);
                    entities = neuralLayer.getNeurons();
                }
                break;
        }
    } else if (nOutput > 1) {
        NeuralLayer neuralLayer = encodeNeuralLayer("output", nOutput, entities, wts, offset);

        if (softmax != null && softmax.asScalar()) {
            // Censored softmax has no PMML counterpart
            if (censored != null && censored.asScalar()) {
                throw new IllegalArgumentException("Censored softmax output is not supported");
            }
            neuralLayer.setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
        }

        offset += (nOutput * (entities.size() + 1));

        neuralLayers.add(neuralLayer);
        entities = neuralLayer.getNeurons();
    } else {
        throw new IllegalArgumentException("Expected a positive number of output units, got " + nOutput);
    }

    // Every weight must have been consumed by the layers above. Leftover weights mean the
    // fitted network has connections this layout does not model (such as skip-layer
    // connections). Ignoring them would silently produce a wrong model.
    if (offset != wts.size()) {
        throw new IllegalArgumentException("Expected " + offset + " weights, got " + wts.size());
    }

    NeuralNetwork neuralNetwork = new NeuralNetwork(miningFunction, NeuralNetwork.ActivationFunction.IDENTITY, ModelUtil.createMiningSchema(label), neuralInputs, neuralLayers);

    switch(miningFunction) {
        case REGRESSION:
            neuralNetwork.setNeuralOutputs(NeuralNetworkUtil.createRegressionNeuralOutputs(entities, (ContinuousLabel) label));
            break;
        case CLASSIFICATION:
            neuralNetwork.setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(entities, (CategoricalLabel) label)).setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, (CategoricalLabel) label));
            break;
    }

    return neuralNetwork;
}
Also used: NeuralInputs (org.dmg.pmml.neural_network.NeuralInputs), CategoricalLabel (org.jpmml.converter.CategoricalLabel), ContinuousLabel (org.jpmml.converter.ContinuousLabel), Label (org.jpmml.converter.Label), ArrayList (java.util.ArrayList), NeuralLayer (org.dmg.pmml.neural_network.NeuralLayer), NeuralNetwork (org.dmg.pmml.neural_network.NeuralNetwork), List (java.util.List), MiningFunction (org.dmg.pmml.MiningFunction)

Aggregations

NeuralLayer (org.dmg.pmml.neural_network.NeuralLayer): 6, ArrayList (java.util.ArrayList): 5, Neuron (org.dmg.pmml.neural_network.Neuron): 5, NeuralInputs (org.dmg.pmml.neural_network.NeuralInputs): 4, NeuralNetwork (org.dmg.pmml.neural_network.NeuralNetwork): 4, ContinuousLabel (org.jpmml.converter.ContinuousLabel): 3, List (java.util.List): 2, CategoricalLabel (org.jpmml.converter.CategoricalLabel): 2, Label (org.jpmml.converter.Label): 2, MultilayerPerceptronClassificationModel (org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel): 1, Vector (org.apache.spark.ml.linalg.Vector): 1, MiningFunction (org.dmg.pmml.MiningFunction): 1, Connection (org.dmg.pmml.neural_network.Connection): 1, ActivationFunction (org.dmg.pmml.neural_network.NeuralNetwork.ActivationFunction): 1, NeuralOutputs (org.dmg.pmml.neural_network.NeuralOutputs): 1