Search in sources:

Example 1 with MultilayerPerceptronClassificationModel

Use of org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel in the jpmml-sparkml project by jpmml.

Method registerOutputFields of class MultilayerPerceptronClassificationModelConverter.

/**
 * Returns the output fields registered by the superclass, extended with probability output
 * fields when the Spark model does not implement {@link HasProbabilityCol}.
 *
 * <p>NOTE(review): presumably this covers Spark versions whose MLP classification model has no
 * probability column of its own — confirm against the supported Spark versions.
 *
 * @param label the model label; cast to {@link CategoricalLabel} when probability fields are added
 * @param encoder the encoder, passed through to the superclass
 * @return the superclass output fields, plus DOUBLE-typed probability fields for the label's
 *     values when the model lacks a probability column
 */
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    MultilayerPerceptronClassificationModel mlpModel = getTransformer();
    List<OutputField> inheritedFields = super.registerOutputFields(label, encoder);

    // The model already exposes probabilities: nothing to add.
    if (mlpModel instanceof HasProbabilityCol) {
        return inheritedFields;
    }

    CategoricalLabel targetLabel = (CategoricalLabel) label;

    // Copy before appending; the superclass list may not be modifiable.
    List<OutputField> extendedFields = new ArrayList<>(inheritedFields);
    extendedFields.addAll(ModelUtil.createProbabilityFields(DataType.DOUBLE, targetLabel.getValues()));
    return extendedFields;
}
Also used: HasProbabilityCol (org.apache.spark.ml.param.shared.HasProbabilityCol), CategoricalLabel (org.jpmml.converter.CategoricalLabel), MultilayerPerceptronClassificationModel (org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel), OutputField (org.dmg.pmml.OutputField)

Example 2 with MultilayerPerceptronClassificationModel

Use of org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel in the jpmml-sparkml project by jpmml.

Method encodeModel of class MultilayerPerceptronClassificationModelConverter.

/**
 * Encodes the fitted Spark multilayer perceptron as a PMML {@link NeuralNetwork}.
 *
 * <p>{@code model.weights()} is read as one flat vector, one layer transition at a time.
 * <ul>
 *   <li>Each transition has {@code rows} incoming entities and {@code columns} neurons.</li>
 *   <li>The vector first holds {@code rows * columns} connection weights. The weight from
 *       input {@code row} to neuron {@code column} sits at relative offset
 *       {@code row * columns + column}.</li>
 *   <li>These are followed by {@code columns} bias terms, one per neuron.</li>
 * </ul>
 *
 * <p>Hidden layers inherit the network-level LOGISTIC activation. The final layer overrides it
 * with IDENTITY activation plus SOFTMAX normalization.
 *
 * @param schema the converter schema; its label is expected to be a {@link CategoricalLabel}
 * @return the encoded classification neural network
 * @throws IllegalArgumentException in any of these cases:
 *     <ul>
 *       <li>the model has fewer than two layers</li>
 *       <li>the number of target categories differs from the output layer size</li>
 *       <li>the number of features differs from the input layer size</li>
 *       <li>the number of weights differs from what the layer sizes require</li>
 *     </ul>
 */
@Override
public NeuralNetwork encodeModel(Schema schema) {
    MultilayerPerceptronClassificationModel model = getTransformer();

    int[] layers = model.layers();
    Vector weights = model.weights();

    // An MLP needs at least an input and an output layer; otherwise layers[layers.length - 1]
    // is either out of bounds or describes no connections at all.
    if (layers.length < 2) {
        throw new IllegalArgumentException("Expected at least 2 layers (input and output), got " + layers.length);
    }

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    int outputSize = layers[layers.length - 1];
    if (categoricalLabel.size() != outputSize) {
        throw new IllegalArgumentException("Expected " + outputSize + " target categories, got " + categoricalLabel.size());
    }

    List<? extends Feature> features = schema.getFeatures();
    if (features.size() != layers[0]) {
        throw new IllegalArgumentException("Expected " + layers[0] + " features, got " + features.size());
    }

    // Validate the flat weight vector up front, so that an undersized vector is reported clearly
    // instead of failing inside weights.apply(..). Each transition contributes (inputs + 1) * outputs
    // values: the connection weights plus one bias per neuron.
    int expectedWeights = 0;
    for (int layer = 1; layer < layers.length; layer++) {
        expectedWeights += (layers[layer - 1] + 1) * layers[layer];
    }
    if (weights.size() != expectedWeights) {
        throw new IllegalArgumentException("Expected " + expectedWeights + " weights, got " + weights.size());
    }

    NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.DOUBLE);

    // The entities feeding the current layer: first the network inputs, then the previous layer's neurons.
    List<? extends Entity> entities = neuralInputs.getNeuralInputs();

    List<NeuralLayer> neuralLayers = new ArrayList<>();

    int weightPos = 0;
    for (int layer = 1; layer < layers.length; layer++) {
        NeuralLayer neuralLayer = new NeuralLayer();

        int rows = entities.size();
        int columns = layers[layer];

        // Bias terms directly follow this transition's rows * columns connection weights.
        int biasPos = weightPos + (rows * columns);

        for (int column = 0; column < columns; column++) {
            List<Double> weightVector = new ArrayList<>(rows);
            for (int row = 0; row < rows; row++) {
                weightVector.add(weights.apply(weightPos + (row * columns) + column));
            }

            Double bias = weights.apply(biasPos + column);

            Neuron neuron = NeuralNetworkUtil.createNeuron(entities, weightVector, bias)
                .setId(layer + "/" + (column + 1));

            neuralLayer.addNeurons(neuron);
        }

        weightPos = biasPos + columns;

        if (layer == (layers.length - 1)) {
            neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY)
                .setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
        }

        neuralLayers.add(neuralLayer);

        entities = neuralLayer.getNeurons();
    }

    // Consistency check: every weight must have been consumed. This can only fail if the number
    // of neural inputs differs from the number of features.
    if (weightPos != weights.size()) {
        throw new IllegalArgumentException("Consumed " + weightPos + " of " + weights.size() + " weights");
    }

    NeuralNetwork neuralNetwork = new NeuralNetwork(MiningFunction.CLASSIFICATION, NeuralNetwork.ActivationFunction.LOGISTIC, ModelUtil.createMiningSchema(categoricalLabel), neuralInputs, neuralLayers)
        .setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(entities, categoricalLabel));

    return neuralNetwork;
}
Also used: NeuralInputs (org.dmg.pmml.neural_network.NeuralInputs), ArrayList (java.util.ArrayList), NeuralLayer (org.dmg.pmml.neural_network.NeuralLayer), NeuralNetwork (org.dmg.pmml.neural_network.NeuralNetwork), Neuron (org.dmg.pmml.neural_network.Neuron), CategoricalLabel (org.jpmml.converter.CategoricalLabel), MultilayerPerceptronClassificationModel (org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel), List (java.util.List), Vector (org.apache.spark.ml.linalg.Vector)

Aggregations

MultilayerPerceptronClassificationModel (org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel): 2; CategoricalLabel (org.jpmml.converter.CategoricalLabel): 2; ArrayList (java.util.ArrayList): 1; List (java.util.List): 1; Vector (org.apache.spark.ml.linalg.Vector): 1; HasProbabilityCol (org.apache.spark.ml.param.shared.HasProbabilityCol): 1; OutputField (org.dmg.pmml.OutputField): 1; NeuralInputs (org.dmg.pmml.neural_network.NeuralInputs): 1; NeuralLayer (org.dmg.pmml.neural_network.NeuralLayer): 1; NeuralNetwork (org.dmg.pmml.neural_network.NeuralNetwork): 1; Neuron (org.dmg.pmml.neural_network.Neuron): 1