Usage of org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel in the jpmml-sparkml project (by jpmml):
the registerOutputFields method of the MultilayerPerceptronClassificationModelConverter class.
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    MultilayerPerceptronClassificationModel model = getTransformer();

    List<OutputField> result = super.registerOutputFields(label, encoder);

    // When the model already exposes a probability column (HasProbabilityCol),
    // the superclass has presumably registered the probability fields already;
    // nothing more to do here.
    if (model instanceof HasProbabilityCol) {
        return result;
    }

    // NOTE(review): assumes the label of an MLP classifier is always categorical — confirm.
    CategoricalLabel categoricalLabel = (CategoricalLabel) label;

    // Copy before mutating: the list returned by the superclass may be unmodifiable/shared.
    List<OutputField> augmented = new ArrayList<>(result);
    augmented.addAll(ModelUtil.createProbabilityFields(DataType.DOUBLE, categoricalLabel.getValues()));

    return augmented;
}
Usage of org.apache.spark.ml.classification.MultilayerPerceptronClassificationModel in the jpmml-sparkml project (by jpmml):
the encodeModel method of the MultilayerPerceptronClassificationModelConverter class.
@Override
public NeuralNetwork encodeModel(Schema schema) {
// Converts a Spark MLP classifier into a PMML NeuralNetwork by unpacking the
// flat weight vector layer by layer. `layers` holds the unit count of each
// layer, including the input layer at index 0 and the output layer at the end.
MultilayerPerceptronClassificationModel model = getTransformer();
int[] layers = model.layers();
Vector weights = model.weights();
CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
// Sanity check: the output layer must have one unit per target category.
if (categoricalLabel.size() != layers[layers.length - 1]) {
throw new IllegalArgumentException();
}
List<? extends Feature> features = schema.getFeatures();
// Sanity check: the input layer must have one unit per feature.
if (features.size() != layers[0]) {
throw new IllegalArgumentException();
}
NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.DOUBLE);
// `entities` always refers to the units of the previous layer; it starts as the
// neural inputs and is replaced with each layer's neurons as we go.
List<? extends Entity> entities = neuralInputs.getNeuralInputs();
List<NeuralLayer> neuralLayers = new ArrayList<>();
// Running read offset into the flat weight vector. For each layer the vector
// holds the (rows x columns) connection-weight matrix first, then the
// `columns` bias terms — NOTE(review): this layout is assumed from the index
// arithmetic below; confirm against Spark's MLP weight serialization.
int weightPos = 0;
for (int layer = 1; layer < layers.length; layer++) {
NeuralLayer neuralLayer = new NeuralLayer();
int rows = entities.size();
int columns = layers[layer];
// First pass: gather the incoming-weight vector for each neuron (column) of
// this layer. Element (row, column) lives at weightPos + row * columns + column.
List<List<Double>> weightMatrix = new ArrayList<>();
for (int column = 0; column < columns; column++) {
List<Double> weightVector = new ArrayList<>();
for (int row = 0; row < rows; row++) {
weightVector.add(weights.apply(weightPos + (row * columns) + column));
}
weightMatrix.add(weightVector);
}
// Advance past the weight matrix; the bias terms follow immediately.
weightPos += (rows * columns);
// Second pass: create one neuron per column, consuming one bias each.
for (int column = 0; column < columns; column++) {
List<Double> weightVector = weightMatrix.get(column);
Double bias = weights.apply(weightPos);
// Neuron ids are "<layer>/<1-based unit index>" for readability in the PMML output.
Neuron neuron = NeuralNetworkUtil.createNeuron(entities, weightVector, bias).setId(String.valueOf(layer) + "/" + String.valueOf(column + 1));
neuralLayer.addNeurons(neuron);
weightPos++;
}
// The output layer uses identity activation plus softmax normalization to
// produce class probabilities; hidden layers inherit the network-level
// activation function (logistic, set below).
if (layer == (layers.length - 1)) {
neuralLayer.setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY).setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
}
neuralLayers.add(neuralLayer);
// This layer's neurons feed the next layer.
entities = neuralLayer.getNeurons();
}
// Every element of the weight vector must have been consumed exactly once.
if (weightPos != weights.size()) {
throw new IllegalArgumentException();
}
NeuralNetwork neuralNetwork = new NeuralNetwork(MiningFunction.CLASSIFICATION, NeuralNetwork.ActivationFunction.LOGISTIC, ModelUtil.createMiningSchema(categoricalLabel), neuralInputs, neuralLayers).setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(entities, categoricalLabel));
return neuralNetwork;
}
Aggregations