Use of org.dmg.pmml.regression.RegressionTable in project jpmml-sparkml by jpmml.
Method encodeModel of class LogisticRegressionModelConverter.
/**
 * Converts the Spark logistic regression model into a PMML {@code RegressionModel}.
 *
 * <p>The result depends on the number of label categories:
 * <ul>
 *   <li>Two categories: a binary logistic classifier with LOGIT normalization is built from
 *       the model's coefficient vector and scalar intercept, and its output element is
 *       cleared.</li>
 *   <li>More than two categories: one regression table is built per category, from the
 *       matching row of the coefficient matrix and entry of the intercept vector. The tables
 *       are combined under SOFTMAX normalization.</li>
 * </ul>
 *
 * @param schema the conversion schema; its label must be a {@code CategoricalLabel}
 * @return the encoded regression model
 * @throws IllegalArgumentException if the label has fewer than two categories
 */
@Override
public RegressionModel encodeModel(Schema schema) {
    LogisticRegressionModel model = getTransformer();
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    int categoryCount = categoricalLabel.size();

    if (categoryCount < 2) {
        throw new IllegalArgumentException();
    }

    if (categoryCount == 2) {
        // Binary case: a single coefficient vector plus a scalar intercept.
        return RegressionModelUtil.createBinaryLogisticClassification(schema.getFeatures(), VectorUtil.toList(model.coefficients()), model.intercept(), RegressionModel.NormalizationMethod.LOGIT, true, schema)
            .setOutput(null);
    }

    // Multinomial case: row k of the coefficient matrix and element k of the intercept
    // vector describe target category k.
    Matrix coefficients = model.coefficientMatrix();
    Vector intercepts = model.interceptVector();
    List<? extends Feature> features = schema.getFeatures();
    List<RegressionTable> tables = new ArrayList<>();
    for (int category = 0; category < categoryCount; category++) {
        RegressionTable table = RegressionModelUtil.createRegressionTable(features, MatrixUtil.getRow(coefficients, category), intercepts.apply(category));
        tables.add(table.setTargetCategory(categoricalLabel.getValue(category)));
    }

    return new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), tables)
        .setNormalizationMethod(RegressionModel.NormalizationMethod.SOFTMAX);
}
Use of org.dmg.pmml.regression.RegressionTable in project pyramid by cheng-li.
Method createClassification of class MiningModelUtil.
/**
 * Builds a classification model chain from one member model per target category.
 *
 * <p>The given models are followed by a final {@code RegressionModel} classifier. That
 * classifier has one regression table per target category. Table {@code i} uses the
 * prediction of {@code models.get(i)} (via {@code MODEL_PREDICTION}) as its single feature,
 * with coefficient {@code 1} and a {@code null} intercept. All member models must share the
 * same math context; a model without one is treated as {@code MathContext.DOUBLE}.
 *
 * @param models one model per target category, in the order of the label's values
 * @param normalizationMethod {@code null}, {@code NONE}, {@code SIMPLEMAX} or {@code SOFTMAX}
 * @param hasProbabilityDistribution whether to attach a probability output to the final
 *        classifier
 * @param schema the schema; its label must be a {@code CategoricalLabel}
 * @return the model chain built by {@code createModelChain}
 * @throws IllegalArgumentException if the model count differs from the category count, if
 *         the normalization method is unsupported, or if the models' math contexts differ
 */
public static MiningModel createClassification(List<? extends Model> models, RegressionModel.NormalizationMethod normalizationMethod, boolean hasProbabilityDistribution, Schema schema) {
    CategoricalLabel categoricalLabel = (CategoricalLabel) schema.getLabel();
    // modified here: the only structural requirement is exactly one model per target category.
    if (categoricalLabel.size() != models.size()) {
        throw new IllegalArgumentException("Expected " + categoricalLabel.size() + " models (one per target category), got " + models.size());
    }
    if (normalizationMethod != null) {
        switch (normalizationMethod) {
            case NONE:
            case SIMPLEMAX:
            case SOFTMAX:
                break;
            default:
                throw new IllegalArgumentException("Unsupported normalization method: " + normalizationMethod);
        }
    }
    MathContext mathContext = null;
    List<RegressionTable> regressionTables = new ArrayList<>();
    for (int i = 0; i < categoricalLabel.size(); i++) {
        Model model = models.get(i);
        MathContext modelMathContext = model.getMathContext();
        if (modelMathContext == null) {
            modelMathContext = MathContext.DOUBLE;
        }
        if (mathContext == null) {
            // The first model fixes the math context every other model must match.
            mathContext = modelMathContext;
        } else if (!Objects.equals(mathContext, modelMathContext)) {
            throw new IllegalArgumentException("Model at index " + i + " uses math context " + modelMathContext + ", expected " + mathContext);
        }
        Feature feature = MODEL_PREDICTION.apply(model);
        RegressionTable regressionTable = RegressionModelUtil.createRegressionTable(Collections.singletonList(feature), Collections.singletonList(1d), null).setTargetCategory(categoricalLabel.getValue(i));
        regressionTables.add(regressionTable);
    }
    RegressionModel regressionModel = new RegressionModel(MiningFunction.CLASSIFICATION, ModelUtil.createMiningSchema(categoricalLabel), regressionTables).setNormalizationMethod(normalizationMethod).setMathContext(ModelUtil.simplifyMathContext(mathContext)).setOutput(hasProbabilityDistribution ? ModelUtil.createProbabilityOutput(mathContext, categoricalLabel) : null);
    // The member models are evaluated first; the combining classifier comes last in the chain.
    List<Model> segmentationModels = new ArrayList<>(models);
    segmentationModels.add(regressionModel);
    return createModelChain(segmentationModels, schema);
}
Use of org.dmg.pmml.regression.RegressionTable in project shifu by ShifuML.
Method getRegressionTable of class PMMLAdapterCommonUtil.
/**
 * Generate a Regression Table from the weight list and intercept, and append it to the given
 * partial PMML model.
 *
 * <p>The model is mutated in place:
 * <ul>
 *   <li>its mining function is set to {@code REGRESSION} and its normalization to
 *       {@code LOGIT};</li>
 *   <li>its target field name is set from the first TARGET field of its mining schema;</li>
 *   <li>one numeric predictor is added for each derived field whose expression is a
 *       {@code NormContinuous} over an ACTIVE mining field.</li>
 * </ul>
 * Weights are consumed in the order those derived fields appear in the local
 * transformations.
 *
 * @param weights
 *            weight list for the Regression Table; assumed to hold one weight per matching
 *            NormContinuous derived field, in that order (TODO confirm with callers)
 * @param intercept
 *            the intercept of the Regression Table
 * @param pmmlModel
 *            partial PMML model; modified in place
 * @return the same {@code pmmlModel} instance, with the new Regression Table appended
 * @throws IllegalArgumentException
 *             if the mining schema declares no TARGET field
 */
public static RegressionModel getRegressionTable(final double[] weights, final double intercept, RegressionModel pmmlModel) {
    RegressionTable table = new RegressionTable();
    // The intercept argument used to be ignored, so the exported table never carried the
    // trained bias.
    table.setIntercept(intercept);
    MiningSchema schema = pmmlModel.getMiningSchema();
    // TODO may not need target field in LRModel
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    List<String> outputFields = getSchemaFieldViaUsageType(schema, UsageType.TARGET);
    if (outputFields.isEmpty()) {
        throw new IllegalArgumentException("Mining schema declares no TARGET field");
    }
    // TODO only one outputField, what if we have more than one outputField
    String targetField = outputFields.get(0);
    pmmlModel.setTargetFieldName(new FieldName(targetField));
    // NOTE(review): targetCategory is set to the target field's name, not a category value.
    // Confirm this is intended.
    table.setTargetCategory(targetField);
    List<String> activeFields = getSchemaFieldViaUsageType(schema, UsageType.ACTIVE);
    int index = 0;
    // NOTE(review): assumes getLocalTransformations() is non-null for these models.
    for (DerivedField dField : pmmlModel.getLocalTransformations().getDerivedFields()) {
        Expression expression = dField.getExpression();
        if (expression instanceof NormContinuous) {
            NormContinuous norm = (NormContinuous) expression;
            if (activeFields.contains(norm.getField().getValue())) {
                table.addNumericPredictors(new NumericPredictor(dField.getName(), weights[index++]));
            }
        }
    }
    pmmlModel.addRegressionTables(table);
    return pmmlModel;
}
Use of org.dmg.pmml.regression.RegressionTable in project shifu by ShifuML.
Method adaptMLModelToPMML of class PMMLLRModelBuilder.
/**
 * Copies the bias and weights of a trained {@code LR} model into {@code pmmlModel} as a single
 * regression table.
 *
 * <p>The model is switched to {@code LOGIT} normalization and the {@code REGRESSION} mining
 * function. The table's intercept is {@code lr.getBias()}.
 *
 * <p>For each ACTIVE mining field, in mining-schema order, one numeric predictor is added.
 * The predictor is named after the last derived field reached by following the chain of
 * {@code NormContinuous}, {@code MapValues} (first field-column pair) and {@code Discretize}
 * transformations from that mining field. If no such transformation exists, the mining field
 * itself is used. Its coefficient is the next entry of {@code lr.getWeights()}. When several
 * derived fields share the same source field, the one declared last wins.
 *
 * @param lr the trained model; assumed to have at least one weight per ACTIVE mining field
 *        (TODO confirm)
 * @param pmmlModel partial PMML model with local transformations; modified in place
 * @return the same {@code pmmlModel} instance, with the new regression table appended
 * @throws IllegalArgumentException if the derived-field transformations form a cycle
 */
public RegressionModel adaptMLModelToPMML(ml.shifu.shifu.core.LR lr, RegressionModel pmmlModel) {
    pmmlModel.setNormalizationMethod(NormalizationMethod.LOGIT);
    pmmlModel.setMiningFunction(MiningFunction.REGRESSION);
    RegressionTable table = new RegressionTable();
    table.setIntercept(lr.getBias());
    // NOTE(review): assumes getLocalTransformations() is non-null for these models.
    LocalTransformations lt = pmmlModel.getLocalTransformations();
    List<DerivedField> df = lt.getDerivedFields();
    // Maps each source field name to the name of the derived field computed from it.
    HashMap<FieldName, FieldName> miningTransformMap = new HashMap<FieldName, FieldName>();
    for (DerivedField dField : df) {
        if (dField.getExpression() instanceof NormContinuous) {
            // Apply z-scale normalization on numerical variables
            miningTransformMap.put(((NormContinuous) dField.getExpression()).getField(), dField.getName());
        } else if (dField.getExpression() instanceof MapValues) {
            // Apply bin map on categorical variables
            miningTransformMap.put(((MapValues) dField.getExpression()).getFieldColumnPairs().get(0).getField(), dField.getName());
        } else if (dField.getExpression() instanceof Discretize) {
            miningTransformMap.put(((Discretize) dField.getExpression()).getField(), dField.getName());
        }
    }
    List<MiningField> miningList = pmmlModel.getMiningSchema().getMiningFields();
    int index = 0;
    for (MiningField mField : miningList) {
        if (mField.getUsageType() != UsageType.ACTIVE) {
            continue;
        }
        FieldName mFieldName = mField.getName();
        FieldName fName = mFieldName;
        // Follow source -> derived -> derived ... to the end of the chain. An acyclic chain
        // has at most miningTransformMap.size() hops. Exceeding that means the
        // transformations are cyclic (e.g. a derived field named after its own source field),
        // which previously made this loop spin forever.
        int hops = 0;
        while (miningTransformMap.containsKey(fName)) {
            if (++hops > miningTransformMap.size()) {
                throw new IllegalArgumentException("Cyclic derived field chain for mining field " + mFieldName.getValue());
            }
            fName = miningTransformMap.get(fName);
        }
        NumericPredictor np = new NumericPredictor();
        np.setName(fName);
        np.setCoefficient(lr.getWeights()[index++]);
        table.addNumericPredictors(np);
    }
    pmmlModel.addRegressionTables(table);
    return pmmlModel;
}
Aggregations