Use of org.dmg.pmml.VerificationField in project jpmml-r by jpmml.
Class ModelConverter, method encodeVerificationData.
/**
 * Converts R data columns into verification data, keyed by verification field.
 *
 * <p>Double columns have their NaN elements mapped to {@code null}, factor columns contribute
 * their factor values, and every other column contributes its values unchanged.</p>
 *
 * @param columns the data columns; every element is cast to {@code RVector}
 * @param names the column names, read at the same index as the corresponding column
 * @return an insertion-ordered map from verification field to the values of the column
 */
protected static Map<VerificationField, List<?>> encodeVerificationData(List<? extends RExp> columns, List<String> names) {
    // Stateless converter, so a single instance can serve every double column
    Function<Double, Double> nanToNull = new Function<Double, Double>() {

        @Override
        public Double apply(Double value) {
            return (value.isNaN() ? null : value);
        }
    };

    Map<VerificationField, List<?>> verificationData = new LinkedHashMap<>();

    for (int index = 0; index < columns.size(); index++) {
        String columnName = names.get(index);
        RVector<?> vector = (RVector<?>) columns.get(index);

        List<?> columnValues;

        if (vector instanceof RDoubleVector) {
            // Raw cast: the element type is not statically known through the wildcard
            columnValues = Lists.transform((List) vector.getValues(), nanToNull);
        } else if (vector instanceof RFactorVector) {
            columnValues = ((RFactorVector) vector).getFactorValues();
        } else {
            columnValues = vector.getValues();
        }

        verificationData.put(ModelUtil.createVerificationField(columnName), columnValues);
    }

    return verificationData;
}
Use of org.dmg.pmml.VerificationField in project jpmml-r by jpmml.
Class ModelConverter, method encodePMML.
/**
 * Encodes the wrapped R object as a PMML document.
 *
 * <p>If the object carries an optional {@code verification} attribute (S4 objects) or element
 * (generic vectors), and that element provides output values or target values, then a model
 * verification element is attached to the encoded model. Output values take precedence over
 * target values.</p>
 *
 * @param encoder the encoder that supplies the schema and produces the PMML document
 * @return the encoded PMML document
 */
@Override
public PMML encodePMML(RExpEncoder encoder) {
    RExp object = getObject();

    RGenericVector verification;

    if (object instanceof S4Object) {
        verification = ((S4Object) object).getGenericAttribute("verification", false);
    } else if (object instanceof RGenericVector) {
        verification = ((RGenericVector) object).getGenericElement("verification", false);
    } else {
        verification = null;
    }

    encodeSchema(encoder);

    Schema schema = encoder.createSchema();
    Model model = encode(schema);

    if (verification != null) {
        RDoubleVector precisionVector = verification.getDoubleElement("precision");
        RDoubleVector zeroThresholdVector = verification.getDoubleElement("zeroThreshold");

        VerificationMap data = new VerificationMap(precisionVector.asScalar(), zeroThresholdVector.asScalar());

        RGenericVector activeValues = verification.getGenericElement("active_values");
        RGenericVector targetValues = verification.getGenericElement("target_values", false);
        RGenericVector outputValues = verification.getGenericElement("output_values", false);

        if (activeValues != null) {
            data.putInputData(encodeActiveValues(activeValues));
        }

        // Stays true only if one of the branches below contributes result data
        boolean hasResultData = true;

        if (outputValues != null) {
            data.putResultData(encodeOutputValues(outputValues));
        } else if (targetValues != null) {
            Label label = schema.getLabel();
            String labelName = label.getName();

            // Drop any previously collected field that refers to the label, before adding the target values
            Collection<VerificationField> collectedFields = data.keySet();
            Iterator<VerificationField> fieldIt = collectedFields.iterator();
            while (fieldIt.hasNext()) {
                VerificationField collectedField = fieldIt.next();

                if ((collectedField.requireField()).equals(labelName)) {
                    fieldIt.remove();
                }
            }

            data.putResultData(encodeTargetValues(targetValues, label));
        } else {
            hasResultData = false;
        }

        if (hasResultData) {
            model.setModelVerification(ModelUtil.createModelVerification(data));
        }
    }

    return encoder.encodePMML(model);
}
Use of org.dmg.pmml.VerificationField in project jpmml-sparkml by jpmml.
Class PMMLBuilder, method build.
/**
 * Converts the configured pipeline model into a PMML document.
 *
 * <p>Each transformer of the pipeline is mapped to a converter. Feature converters register
 * their features with the encoder; model converters register a model, optionally together with
 * feature importances. Derived fields that were created after the last model was registered are
 * moved into the output of the final model as transformed-value output fields. If a verification
 * object is set and at least one prediction or probability column was collected, a model
 * verification element is attached to the model.</p>
 *
 * @return the encoded PMML document
 * @throws IllegalArgumentException if a converter is neither a feature converter nor a model converter
 */
public PMML build() {
    StructType structType = getSchema();
    PipelineModel pipeline = getPipelineModel();
    Map<RegexKey, ? extends Map<String, ?>> options = getOptions();
    Verification verification = getVerification();

    ConverterFactory converterFactory = new ConverterFactory(options);
    SparkMLEncoder encoder = new SparkMLEncoder(structType, converterFactory);

    Map<FieldName, DerivedField> derivedFields = encoder.getDerivedFields();

    List<org.dmg.pmml.Model> models = new ArrayList<>();

    List<String> predictionColumns = new ArrayList<>();
    List<String> probabilityColumns = new ArrayList<>();

    // Snapshot of the derived field names taken right after the most recent model was registered
    List<FieldName> preProcessorNames = Collections.emptyList();

    for (Transformer transformer : getTransformers(pipeline)) {
        TransformerConverter<?> converter = converterFactory.newConverter(transformer);

        if (converter instanceof FeatureConverter) {
            ((FeatureConverter<?>) converter).registerFeatures(encoder);

            continue;
        }

        if (!(converter instanceof ModelConverter)) {
            throw new IllegalArgumentException("Expected a subclass of " + FeatureConverter.class.getName() + " or " + ModelConverter.class.getName() + ", got " + (converter != null ? ("class " + (converter.getClass()).getName()) : null));
        }

        ModelConverter<?> modelConverter = (ModelConverter<?>) converter;

        org.dmg.pmml.Model registeredModel = modelConverter.registerModel(encoder);

        models.add(registeredModel);

        if (modelConverter instanceof HasFeatureImportances) {
            HasFeatureImportances importanceSource = (HasFeatureImportances) modelConverter;

            Boolean estimateImportances = (Boolean) modelConverter.getOption(HasTreeOptions.OPTION_ESTIMATE_FEATURE_IMPORTANCES, Boolean.FALSE);

            // Feature importances are recorded only when the option is switched on
            if (estimateImportances) {
                List<Double> importances = VectorUtil.toList(importanceSource.getFeatureImportances());
                List<Feature> modelFeatures = modelConverter.getFeatures(encoder);

                SchemaUtil.checkSize(importances.size(), modelFeatures);

                for (int index = 0; index < importances.size(); index++) {
                    Double importance = importances.get(index);
                    Feature modelFeature = modelFeatures.get(index);

                    encoder.addFeatureImportance(registeredModel, modelFeature, importance);
                }
            }
        }

        if (transformer instanceof HasPredictionCol) {
            HasPredictionCol predictionSource = (HasPredictionCol) transformer;

            // A generalized linear regression model that was encoded as a classifier contributes no prediction column
            boolean classificationGlm = (transformer instanceof GeneralizedLinearRegressionModel) && (MiningFunction.CLASSIFICATION).equals(registeredModel.getMiningFunction());

            if (!classificationGlm) {
                predictionColumns.add(predictionSource.getPredictionCol());
            }
        }

        if (transformer instanceof HasProbabilityCol) {
            HasProbabilityCol probabilitySource = (HasProbabilityCol) transformer;

            probabilityColumns.add(probabilitySource.getProbabilityCol());
        }

        preProcessorNames = new ArrayList<>(derivedFields.keySet());
    }

    // Derived fields that were created after the last model was registered
    List<FieldName> postProcessorNames = new ArrayList<>(derivedFields.keySet());
    postProcessorNames.removeAll(preProcessorNames);

    org.dmg.pmml.Model model;

    if (models.size() > 1) {
        model = MiningModelUtil.createModelChain(models);
    } else if (models.size() == 1) {
        model = Iterables.getOnlyElement(models);
    } else {
        model = null;
    }

    if ((model != null) && !postProcessorNames.isEmpty()) {
        org.dmg.pmml.Model finalModel = MiningModelUtil.getFinalModel(model);

        Output output = ModelUtil.ensureOutput(finalModel);

        for (FieldName postProcessorName : postProcessorNames) {
            DerivedField derivedField = derivedFields.get(postProcessorName);

            encoder.removeDerivedField(postProcessorName);

            // Re-express the derived field as an output field of the final model
            OutputField outputField = new OutputField(derivedField.getName(), derivedField.getOpType(), derivedField.getDataType())
                .setResultFeature(ResultFeature.TRANSFORMED_VALUE)
                .setExpression(derivedField.getExpression());

            output.addOutputFields(outputField);
        }
    }

    PMML pmml = encoder.encodePMML(model);

    boolean hasResultColumns = !predictionColumns.isEmpty() || !probabilityColumns.isEmpty();

    if ((model == null) || !hasResultColumns || (verification == null)) {
        return pmml;
    }

    Dataset<Row> dataset = verification.getDataset();
    Dataset<Row> transformedDataset = verification.getTransformedDataset();

    Double precision = verification.getPrecision();
    Double zeroThreshold = verification.getZeroThreshold();

    // Input columns are the active mining fields of the model
    List<String> inputColumns = new ArrayList<>();

    for (MiningField miningField : model.getMiningSchema().getMiningFields()) {
        MiningField.UsageType usageType = miningField.getUsageType();

        switch (usageType) {
            case ACTIVE:
                FieldName activeName = miningField.getName();

                inputColumns.add(activeName.getValue());
                break;
            default:
                break;
        }
    }

    Map<VerificationField, List<?>> data = new LinkedHashMap<>();

    for (String inputColumn : inputColumns) {
        VerificationField inputField = ModelUtil.createVerificationField(FieldName.create(inputColumn));

        data.put(inputField, getColumn(dataset, inputColumn));
    }

    for (String predictionColumn : predictionColumns) {
        Feature predictionFeature = encoder.getOnlyFeature(predictionColumn);

        VerificationField predictionField = ModelUtil.createVerificationField(predictionFeature.getName())
            .setPrecision(precision)
            .setZeroThreshold(zeroThreshold);

        data.put(predictionField, getColumn(transformedDataset, predictionColumn));
    }

    for (String probabilityColumn : probabilityColumns) {
        List<Feature> probabilityFeatures = encoder.getFeatures(probabilityColumn);

        // One verification field per element of the probability vector column
        for (int index = 0; index < probabilityFeatures.size(); index++) {
            Feature probabilityFeature = probabilityFeatures.get(index);

            VerificationField probabilityField = ModelUtil.createVerificationField(probabilityFeature.getName())
                .setPrecision(precision)
                .setZeroThreshold(zeroThreshold);

            data.put(probabilityField, getVectorColumn(transformedDataset, probabilityColumn, index));
        }
    }

    // The PMML document has already been created from this model instance; it is updated in place
    model.setModelVerification(ModelUtil.createModelVerification(data));

    return pmml;
}
Aggregations