Use of org.jpmml.converter.CategoricalLabel in project jpmml-r by jpmml.
From the class RandomForestConverter, the method encodeClassification:
private MiningModel encodeClassification(RGenericVector forest, final Schema schema) {
    RNumberVector<?> bestvar = (RNumberVector<?>) forest.getValue("bestvar");
    RNumberVector<?> treemap = (RNumberVector<?>) forest.getValue("treemap");
    RIntegerVector nodepred = (RIntegerVector) forest.getValue("nodepred");
    RDoubleVector xbestsplit = (RDoubleVector) forest.getValue("xbestsplit");
    RIntegerVector nrnodes = (RIntegerVector) forest.getValue("nrnodes");
    RDoubleVector ntree = (RDoubleVector) forest.getValue("ntree");

    int rows = nrnodes.asScalar();
    int columns = ValueUtil.asInt(ntree.asScalar());

    final CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();

    ScoreEncoder<Integer> scoreEncoder = new ScoreEncoder<Integer>() {

        @Override
        public String encode(Integer value) {
            return categoricalLabel.getValue(value - 1);
        }
    };

    Schema segmentSchema = schema.toAnonymousSchema();

    List<TreeModel> treeModels = new ArrayList<>();

    for (int i = 0; i < columns; i++) {
        List<? extends Number> daughters = FortranMatrixUtil.getColumn(treemap.getValues(), 2 * rows, columns, i);

        TreeModel treeModel = encodeTreeModel(
            MiningFunction.CLASSIFICATION,
            scoreEncoder,
            FortranMatrixUtil.getColumn(daughters, rows, 2, 0),
            FortranMatrixUtil.getColumn(daughters, rows, 2, 1),
            FortranMatrixUtil.getColumn(nodepred.getValues(), rows, columns, i),
            FortranMatrixUtil.getColumn(bestvar.getValues(), rows, columns, i),
            FortranMatrixUtil.getColumn(xbestsplit.getValues(), rows, columns, i),
            segmentSchema
        );

        treeModels.add(treeModel);
    }

    MiningModel miningModel = new MiningModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel))
        .setSegmentation(MiningModelUtil.createSegmentation(Segmentation.MultipleModelMethod.MAJORITY_VOTE, treeModels))
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

    return miningModel;
}
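The ScoreEncoder above bridges randomForest's 1-based class indices and the CategoricalLabel's 0-based value list. A minimal sketch of that lookup, not taken from the jpmml-r sources; the field name "Species" and its three values are made-up illustration data:

import org.dmg.pmml.DataField;
import org.dmg.pmml.DataType;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.OpType;
import org.dmg.pmml.Value;

import org.jpmml.converter.CategoricalLabel;

public class ScoreEncoderSketch {

    public static void main(String... args) {
        // Hypothetical target field with three factor levels
        DataField dataField = new DataField(FieldName.create("Species"), OpType.CATEGORICAL, DataType.STRING);
        dataField.addValues(new Value("setosa"), new Value("versicolor"), new Value("virginica"));

        CategoricalLabel categoricalLabel = new CategoricalLabel(dataField);

        // randomForest reports the predicted class as a 1-based level index,
        // so level 2 resolves to the second category value
        String score = categoricalLabel.getValue(2 - 1);
        System.out.println(score); // prints "versicolor"
    }
}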
Use of org.jpmml.converter.CategoricalLabel in project jpmml-r by jpmml.
From the class RExpEncoder, the method setLabel:
public void setLabel(DataField dataField) {
    Label label;

    OpType opType = dataField.getOpType();
    switch(opType) {
        case CATEGORICAL:
            label = new CategoricalLabel(dataField);
            break;
        case CONTINUOUS:
            label = new ContinuousLabel(dataField);
            break;
        default:
            throw new IllegalArgumentException();
    }

    setLabel(label);
}
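Only categorical and continuous target fields are supported here; any other OpType (for example ORDINAL) falls through to the IllegalArgumentException. A standalone sketch of the same dispatch, written as a hypothetical helper rather than as the encoder method itself:

import org.dmg.pmml.DataField;
import org.dmg.pmml.OpType;

import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ContinuousLabel;
import org.jpmml.converter.Label;

public class LabelDispatchSketch {

    // Hypothetical helper mirroring RExpEncoder#setLabel(DataField)
    static Label toLabel(DataField dataField) {
        OpType opType = dataField.getOpType();
        switch(opType) {
            case CATEGORICAL:
                // discrete target field -> categorical label
                return new CategoricalLabel(dataField);
            case CONTINUOUS:
                // numeric target field -> continuous label
                return new ContinuousLabel(dataField);
            default:
                // ordinal and other op types are rejected
                throw new IllegalArgumentException(String.valueOf(opType));
        }
    }
}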
Use of org.jpmml.converter.CategoricalLabel in project jpmml-sparkml by jpmml.
From the class GeneralizedLinearRegressionModelConverter, the method registerOutputFields:
@Override
public List<OutputField> registerOutputFields(Label label, SparkMLEncoder encoder) {
    List<OutputField> result = super.registerOutputFields(label, encoder);

    MiningFunction miningFunction = getMiningFunction();
    switch(miningFunction) {
        case CLASSIFICATION:
            CategoricalLabel categoricalLabel = (CategoricalLabel) label;

            result = new ArrayList<>(result);
            result.addAll(ModelUtil.createProbabilityFields(DataType.DOUBLE, categoricalLabel.getValues()));
            break;
        default:
            break;
    }

    return result;
}
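In the classification case, the converter appends one probability output field per class value to whatever the superclass has already registered. A hedged sketch of that step in isolation; the helper below is hypothetical, and the exact field names (typically of the form probability(value)) are decided by jpmml-converter's ModelUtil:

import java.util.List;

import org.dmg.pmml.DataType;
import org.dmg.pmml.OutputField;

import org.jpmml.converter.CategoricalLabel;
import org.jpmml.converter.ModelUtil;

public class ProbabilityFieldsSketch {

    // Hypothetical helper: one probability OutputField per class value of the label
    static List<OutputField> createProbabilityFields(CategoricalLabel categoricalLabel) {
        return ModelUtil.createProbabilityFields(DataType.DOUBLE, categoricalLabel.getValues());
    }
}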
Use of org.jpmml.converter.CategoricalLabel in project jpmml-sparkml by jpmml.
From the class GeneralizedLinearRegressionModelConverter, the method encodeModel:
@Override
public GeneralRegressionModel encodeModel(Schema schema) {
    GeneralizedLinearRegressionModel model = getTransformer();

    String targetCategory = null;

    MiningFunction miningFunction = getMiningFunction();
    switch(miningFunction) {
        case CLASSIFICATION:
            CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
            if (categoricalLabel.size() != 2) {
                throw new IllegalArgumentException();
            }

            targetCategory = categoricalLabel.getValue(1);
            break;
        default:
            break;
    }

    GeneralRegressionModel generalRegressionModel = new GeneralRegressionModel(GeneralRegressionModel.ModelType.GENERALIZED_LINEAR, miningFunction, ModelUtil.createMiningSchema(schema.getLabel()), null, null, null)
        .setDistribution(parseFamily(model.getFamily()))
        .setLinkFunction(parseLinkFunction(model.getLink()))
        .setLinkParameter(parseLinkParameter(model.getLink()));

    GeneralRegressionModelUtil.encodeRegressionTable(generalRegressionModel, schema.getFeatures(), model.intercept(), VectorUtil.toList(model.coefficients()), targetCategory);

    return generalRegressionModel;
}
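The classification branch only supports binomial targets: the label must hold exactly two values, and the regression table is encoded against the second one. A small sketch of that check, written as a hypothetical helper:

import org.jpmml.converter.CategoricalLabel;

public class BinomialTargetSketch {

    // Hypothetical helper mirroring the check above: exactly two classes,
    // with the second (index 1) value becoming the target category
    static String selectTargetCategory(CategoricalLabel categoricalLabel) {
        if (categoricalLabel.size() != 2) {
            throw new IllegalArgumentException();
        }
        return categoricalLabel.getValue(1);
    }
}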
Use of org.jpmml.converter.CategoricalLabel in project jpmml-sparkml by jpmml.
From the class MultilayerPerceptronClassificationModelConverter, the method encodeModel:
@Override
public NeuralNetwork encodeModel(Schema schema) {
    MultilayerPerceptronClassificationModel model = getTransformer();

    int[] layers = model.layers();
    Vector weights = model.weights();

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    if (categoricalLabel.size() != layers[layers.length - 1]) {
        throw new IllegalArgumentException();
    }

    List<? extends Feature> features = schema.getFeatures();
    if (features.size() != layers[0]) {
        throw new IllegalArgumentException();
    }

    NeuralInputs neuralInputs = NeuralNetworkUtil.createNeuralInputs(features, DataType.DOUBLE);

    List<? extends Entity> entities = neuralInputs.getNeuralInputs();

    List<NeuralLayer> neuralLayers = new ArrayList<>();

    int weightPos = 0;

    for (int layer = 1; layer < layers.length; layer++) {
        NeuralLayer neuralLayer = new NeuralLayer();

        int rows = entities.size();
        int columns = layers[layer];

        List<List<Double>> weightMatrix = new ArrayList<>();

        for (int column = 0; column < columns; column++) {
            List<Double> weightVector = new ArrayList<>();

            for (int row = 0; row < rows; row++) {
                weightVector.add(weights.apply(weightPos + (row * columns) + column));
            }

            weightMatrix.add(weightVector);
        }

        weightPos += (rows * columns);

        for (int column = 0; column < columns; column++) {
            List<Double> weightVector = weightMatrix.get(column);
            Double bias = weights.apply(weightPos);

            Neuron neuron = NeuralNetworkUtil.createNeuron(entities, weightVector, bias)
                .setId(String.valueOf(layer) + "/" + String.valueOf(column + 1));

            neuralLayer.addNeurons(neuron);

            weightPos++;
        }

        if (layer == (layers.length - 1)) {
            neuralLayer
                .setActivationFunction(NeuralNetwork.ActivationFunction.IDENTITY)
                .setNormalizationMethod(NeuralNetwork.NormalizationMethod.SOFTMAX);
        }

        neuralLayers.add(neuralLayer);

        entities = neuralLayer.getNeurons();
    }

    if (weightPos != weights.size()) {
        throw new IllegalArgumentException();
    }

    NeuralNetwork neuralNetwork = new NeuralNetwork(MiningFunction.CLASSIFICATION, NeuralNetwork.ActivationFunction.LOGISTIC, ModelUtil.createMiningSchema(categoricalLabel), neuralInputs, neuralLayers)
        .setNeuralOutputs(NeuralNetworkUtil.createClassificationNeuralOutputs(entities, categoricalLabel));

    return neuralNetwork;
}
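The loop consumes the flattened weight vector layer by layer: first a rows-by-columns block of connection weights, addressed as weightPos + (row * columns) + column, then one bias per neuron of the layer. A minimal sketch of the same indexing with plain arrays; the layer sizes and weight values are made up:

public class WeightLayoutSketch {

    public static void main(String... args) {
        // Hypothetical layer: 2 incoming entities feeding 3 neurons
        int rows = 2;
        int columns = 3;

        // 2 * 3 connection weights followed by 3 biases
        double[] weights = {
            0.10, 0.20, 0.30, // weights from input 0 to neurons 0..2
            0.40, 0.50, 0.60, // weights from input 1 to neurons 0..2
            0.70, 0.80, 0.90  // biases for neurons 0..2
        };

        int weightPos = 0;

        for (int column = 0; column < columns; column++) {
            for (int row = 0; row < rows; row++) {
                double weight = weights[weightPos + (row * columns) + column];
                System.out.println("neuron " + column + ", input " + row + ": " + weight);
            }
        }

        weightPos += (rows * columns);

        for (int column = 0; column < columns; column++) {
            double bias = weights[weightPos];
            System.out.println("neuron " + column + ", bias: " + bias);

            weightPos++;
        }
    }
}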