Usage of `org.dmg.pmml.naive_bayes.BayesOutput` in the jpmml-r project (by jpmml).
In class `NaiveBayesConverter`, method `encodeModel`:
/**
 * Encodes the R naive Bayes object returned by {@code getObject()} as a PMML {@code NaiveBayesModel}.
 *
 * <p>The R object is read through two elements: {@code "apriori"} (integer class counts) and
 * {@code "tables"} (one numeric table per feature, looked up by feature name, with one row per
 * target class).
 *
 * <ul>
 *   <li>For a categorical feature, each table column becomes a {@code PairCounts} entry. Its
 *       per-class count is {@code apriori[row] * table[row, column]}.</li>
 *   <li>For a continuous feature, each table row becomes a Gaussian {@code TargetValueStat}. The
 *       mean comes from column 0. The variance is column 1 squared, so column 1 is treated as a
 *       standard deviation.</li>
 * </ul>
 *
 * <p>The {@code BayesOutput} carries the raw apriori counts, labelled by the apriori dimnames.
 *
 * @param schema the model schema; its label must be a {@code CategoricalLabel}
 * @return a classification {@code NaiveBayesModel} with threshold {@code 0d} and a probability output
 * @throws ClassCastException if the schema label is not a {@code CategoricalLabel}
 * @throws IllegalArgumentException if a feature is neither a {@code CategoricalFeature} nor a
 *     {@code ContinuousFeature}
 */
@Override
public Model encodeModel(Schema schema) {
    RGenericVector naiveBayes = getObject();

    RIntegerVector apriori = naiveBayes.getIntegerElement("apriori");
    RGenericVector tables = naiveBayes.getGenericElement("tables");

    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    List<? extends Feature> features = schema.getFeatures();

    BayesInputs bayesInputs = new BayesInputs();

    for (Feature feature : features) {
        String name = feature.getName();

        // Per-feature table; its rows are labelled by target class (row dimnames).
        RDoubleVector table = tables.getDoubleElement(name);
        RStringVector tableRows = table.dimnames(0);
        int rowCount = tableRows.size();

        BayesInput bayesInput = new BayesInput(name, null, null);

        if (feature instanceof CategoricalFeature) {
            // Column dimnames (the feature's levels) are only meaningful for categorical tables,
            // so they are read here rather than for every feature.
            RStringVector tableColumns = table.dimnames(1);
            int columnCount = tableColumns.size();

            for (int column = 0; column < columnCount; column++) {
                TargetValueCounts targetValueCounts = new TargetValueCounts();

                List<Double> probabilities = FortranMatrixUtil.getColumn(table.getValues(), rowCount, columnCount, column);

                for (int row = 0; row < rowCount; row++) {
                    // Scale the table value by the class's apriori count to obtain a count.
                    // NOTE(review): assumes apriori entries follow the same class order as the
                    // table rows; confirm against the R object layout.
                    double count = apriori.getValue(row) * probabilities.get(row);

                    TargetValueCount targetValueCount = new TargetValueCount(tableRows.getValue(row), count);
                    targetValueCounts.addTargetValueCounts(targetValueCount);
                }

                PairCounts pairCounts = new PairCounts(tableColumns.getValue(column), targetValueCounts);
                bayesInput.addPairCounts(pairCounts);
            }
        } else if (feature instanceof ContinuousFeature) {
            TargetValueStats targetValueStats = new TargetValueStats();

            for (int row = 0; row < rowCount; row++) {
                // The table is read as rowCount x 2: column 0 is the mean and column 1 is squared
                // into a variance.
                List<Double> stats = FortranMatrixUtil.getRow(table.getValues(), rowCount, 2, row);

                double mean = stats.get(0);
                double variance = Math.pow(stats.get(1), 2);

                TargetValueStat targetValueStat = new TargetValueStat(tableRows.getValue(row), new GaussianDistribution(mean, variance));
                targetValueStats.addTargetValueStats(targetValueStat);
            }

            bayesInput.setTargetValueStats(targetValueStats);
        } else {
            throw new IllegalArgumentException("Feature " + name + " has unsupported type " + feature.getClass().getName());
        }

        bayesInputs.addBayesInputs(bayesInput);
    }

    BayesOutput bayesOutput = new BayesOutput().setField(categoricalLabel.getName());

    {
        // The output counts are the raw apriori class counts, labelled by the apriori dimnames.
        TargetValueCounts targetValueCounts = new TargetValueCounts();

        RStringVector aprioriRows = apriori.dimnames(0);
        for (int row = 0; row < aprioriRows.size(); row++) {
            int count = apriori.getValue(row);

            TargetValueCount targetValueCount = new TargetValueCount(aprioriRows.getValue(row), count);
            targetValueCounts.addTargetValueCounts(targetValueCount);
        }

        bayesOutput.setTargetValueCounts(targetValueCounts);
    }

    NaiveBayesModel naiveBayesModel = new NaiveBayesModel(0d, MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), bayesInputs, bayesOutput)
        .setOutput(ModelUtil.createProbabilityOutput(DataType.DOUBLE, categoricalLabel));

    return naiveBayesModel;
}
Aggregations