use of org.dmg.pmml.FieldRef in project jpmml-r by jpmml.
the class SVMConverter method encodeModel.
/**
 * Encodes the R {@code svm} object held by this converter as a PMML support vector machine model.
 *
 * <p>The numeric {@code type} and {@code kernel} elements of the R object are used as indices
 * into the declaration order of the {@code Type} and {@code Kernel} enums, so those enums must
 * list their constants in the same order as the codes stored by R.</p>
 *
 * <ul>
 *   <li>C- and nu-classification are delegated to {@code encodeClassification}.</li>
 *   <li>One-classification is encoded like a regression whose predicted output is named
 *       {@code decisionFunction} and is post-processed by a {@code lessOrEqual(value, 0)}
 *       outlier transformation. A non-empty {@code y.scale} element is rejected.</li>
 *   <li>Eps- and nu-regression are delegated to {@code encodeRegression}; when {@code y.scale}
 *       is non-empty, the targets are rescaled from its {@code scaled:scale} and
 *       {@code scaled:center} elements.</li>
 * </ul>
 *
 * @param schema the schema passed on to the classification/regression encoders; its label is
 *        cast to {@code ContinuousLabel} when regression targets are rescaled
 * @return the encoded model, with its kernel set from {@code degree}, {@code gamma} and {@code coef0}
 * @throws IllegalArgumentException if the SVM type is not handled, or if a one-classification
 *         model carries a non-empty {@code y.scale} element
 */
@Override
public SupportVectorMachineModel encodeModel(Schema schema) {
    RGenericVector svm = getObject();

    RDoubleVector type = (RDoubleVector) svm.getValue("type");
    RDoubleVector kernel = (RDoubleVector) svm.getValue("kernel");
    RDoubleVector degree = (RDoubleVector) svm.getValue("degree");
    RDoubleVector gamma = (RDoubleVector) svm.getValue("gamma");
    RDoubleVector coef0 = (RDoubleVector) svm.getValue("coef0");
    RGenericVector yScale = (RGenericVector) svm.getValue("y.scale");
    RIntegerVector nSv = (RIntegerVector) svm.getValue("nSV");
    RDoubleVector sv = (RDoubleVector) svm.getValue("SV");
    RDoubleVector rho = (RDoubleVector) svm.getValue("rho");
    RDoubleVector coefs = (RDoubleVector) svm.getValue("coefs");

    // The numeric codes double as positions in the enum declaration order.
    Type svmType = Type.values()[ValueUtil.asInt(type.asScalar())];
    Kernel svmKernel = Kernel.values()[ValueUtil.asInt(kernel.asScalar())];

    // Target scaling information is optional; it may be absent or empty.
    boolean hasYScale = (yScale != null && yScale.size() > 0);

    SupportVectorMachineModel supportVectorMachineModel;

    switch (svmType) {
        case C_CLASSIFICATION:
        case NU_CLASSIFICATION:
            {
                supportVectorMachineModel = encodeClassification(sv, nSv, rho, coefs, schema);
            }
            break;
        case ONE_CLASSIFICATION:
            {
                // Fail fast: reject unsupported input before any model is built.
                if (hasYScale) {
                    throw new IllegalArgumentException("One-classification SVM models must not define target scaling (y.scale)");
                }

                // A decision function value of zero or less is reported as an outlier.
                Transformation outlier = new OutlierTransformation() {

                    @Override
                    public Expression createExpression(FieldRef fieldRef) {
                        return PMMLUtil.createApply("lessOrEqual", fieldRef, PMMLUtil.createConstant(0d));
                    }
                };

                supportVectorMachineModel = encodeRegression(sv, rho, coefs, schema).setOutput(ModelUtil.createPredictedOutput(FieldName.create("decisionFunction"), OpType.CONTINUOUS, DataType.DOUBLE, outlier));
            }
            break;
        case EPS_REGRESSION:
        case NU_REGRESSION:
            {
                supportVectorMachineModel = encodeRegression(sv, rho, coefs, schema);

                if (hasYScale) {
                    RDoubleVector yScaledCenter = (RDoubleVector) yScale.getValue("scaled:center");
                    RDoubleVector yScaledScale = (RDoubleVector) yScale.getValue("scaled:scale");

                    // The scale is negated before being handed to the rescaling helper.
                    supportVectorMachineModel.setTargets(ModelUtil.createRescaleTargets(-1d * yScaledScale.asScalar(), yScaledCenter.asScalar(), (ContinuousLabel) schema.getLabel()));
                }
            }
            break;
        default:
            throw new IllegalArgumentException("Unsupported SVM type: " + svmType);
    }

    supportVectorMachineModel.setKernel(svmKernel.createKernel(degree.asScalar(), gamma.asScalar(), coef0.asScalar()));

    return supportVectorMachineModel;
}
use of org.dmg.pmml.FieldRef in project jpmml-r by jpmml.
the class ExpressionCompactorTest method compactLogicalExpression.
/**
 * Checks that compacting an "or" whose middle operand is itself a nested tree of "or"
 * applications yields a single application listing all leaf comparisons in left-to-right order.
 */
@Test
public void compactLogicalExpression() {
    FieldRef x = new FieldRef(FieldName.create("x"));

    // Leaf comparisons; the constants encode each leaf's position in the original tree.
    Apply leafOne = createApply("equal", x, createConstant("1"));
    Apply leafTwoLeftLeft = createApply("equal", x, createConstant("2/L/L"));
    Apply leafTwoLeftRight = createApply("equal", x, createConstant("2/L/R"));
    Apply leafTwoRight = createApply("equal", x, createConstant("2/R"));
    Apply leafThree = createApply("equal", x, createConstant("3"));

    // Middle operand: or(or(2/L/L, 2/L/R), 2/R).
    Apply nestedLeft = createApply("or", leafTwoLeftLeft, leafTwoLeftRight);
    Apply nested = createApply("or", nestedLeft, leafTwoRight);

    Apply root = createApply("or", leafOne, nested, leafThree);

    Apply compacted = compact(root);

    assertEquals(Arrays.asList(leafOne, leafTwoLeftLeft, leafTwoLeftRight, leafTwoRight, leafThree), compacted.getExpressions());
}
use of org.dmg.pmml.FieldRef in project jpmml-r by jpmml.
the class ExpressionTranslatorTest method translateArithmeticExpressionChain.
/**
 * Checks that {@code "A + B - X + C"} is translated into the nested tree
 * {@code ((A + B) - X) + C}, by walking the tree from the outermost application inwards.
 */
@Test
public void translateArithmeticExpressionChain() {
    Apply outerSum = (Apply) ExpressionTranslator.translateExpression("A + B - X + C");

    // Outermost level: (...) + C
    List<Expression> outerSumArgs = checkApply(outerSum, "+", Apply.class, FieldRef.class);
    Expression difference = outerSumArgs.get(0);
    Expression fieldC = outerSumArgs.get(1);

    // Middle level: (...) - X
    List<Expression> differenceArgs = checkApply((Apply) difference, "-", Apply.class, FieldRef.class);
    checkFieldRef((FieldRef) fieldC, FieldName.create("C"));
    Expression innerSum = differenceArgs.get(0);
    Expression fieldX = differenceArgs.get(1);

    // Innermost level: A + B
    List<Expression> innerSumArgs = checkApply((Apply) innerSum, "+", FieldRef.class, FieldRef.class);
    checkFieldRef((FieldRef) fieldX, FieldName.create("X"));
    Expression fieldA = innerSumArgs.get(0);
    Expression fieldB = innerSumArgs.get(1);

    checkFieldRef((FieldRef) fieldA, FieldName.create("A"));
    checkFieldRef((FieldRef) fieldB, FieldName.create("B"));
}
use of org.dmg.pmml.FieldRef in project jpmml-r by jpmml.
the class PreProcessEncoder method encodeExpression.
/**
 * Builds the expression that applies this encoder's pre-processing steps to a feature.
 *
 * <p>Starting from {@code feature.ref()}, each step is applied only if a value is registered
 * under the feature's name, in this order:</p>
 * <ol>
 *   <li>range scaling: {@code (x - min) / (max - min)}</li>
 *   <li>centering: {@code x - mean}</li>
 *   <li>scaling: {@code x / std}</li>
 *   <li>median imputation: {@code if(isNotMissing(field), x, median)}</li>
 * </ol>
 *
 * @param feature the feature to pre-process
 * @return the composed expression, or {@code null} if the resulting expression is a plain
 *         {@code FieldRef}
 */
private Expression encodeExpression(Feature feature) {
    FieldName name = feature.getName();

    Expression result = feature.ref();

    List<Double> bounds = this.ranges.get(name);
    if (bounds != null) {
        Double lower = bounds.get(0);
        Double upper = bounds.get(1);

        Expression shifted = PMMLUtil.createApply("-", result, PMMLUtil.createConstant(lower));

        result = PMMLUtil.createApply("/", shifted, PMMLUtil.createConstant(upper - lower));
    }

    Double center = this.mean.get(name);
    if (center != null) {
        result = PMMLUtil.createApply("-", result, PMMLUtil.createConstant(center));
    }

    Double scale = this.std.get(name);
    if (scale != null) {
        result = PMMLUtil.createApply("/", result, PMMLUtil.createConstant(scale));
    }

    Double fallback = this.median.get(name);
    if (fallback != null) {
        // The missing-value test targets the raw field, not the transformed expression.
        Expression isPresent = PMMLUtil.createApply("isNotMissing", new FieldRef(name));

        result = PMMLUtil.createApply("if", isPresent, result, PMMLUtil.createConstant(fallback));
    }

    // A bare field reference is reported to the caller as null.
    return (result instanceof FieldRef) ? null : result;
}
use of org.dmg.pmml.FieldRef in project jpmml-sparkml by jpmml.
the class CountVectorizerModelConverter method encodeFeatures.
/**
 * Encodes a Spark ML {@code CountVectorizerModel} as one term feature per vocabulary entry.
 *
 * <p>A term-frequency function with the parameters {@code document} and {@code term} is
 * defined around a {@code TextIndex} expression and registered with the encoder. Every
 * non-empty stop word set of the input document feature contributes one text index
 * normalization, whose regex row replaces matched stop words with a single space.</p>
 *
 * @param encoder the encoder that supplies the input feature and receives the function definition
 * @return one {@code TermFeature} per vocabulary term, in vocabulary order
 * @throws IllegalArgumentException if the word separator pattern is neither {@code \s+} nor
 *         {@code \W+} (checked only when a non-empty stop word set exists), or if a vocabulary
 *         term contains punctuation
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    CountVectorizerModel model = getTransformer();

    DocumentFeature docFeature = (DocumentFeature) encoder.getOnlyFeature(model.getInputCol());

    ParameterField documentParam = new ParameterField(FieldName.create("document"));
    ParameterField termParam = new ParameterField(FieldName.create("term"));

    TextIndex termCounter = new TextIndex(documentParam.getName())
        .setTokenize(Boolean.TRUE)
        .setWordSeparatorCharacterRE(docFeature.getWordSeparatorRE())
        .setLocalTermWeights(model.getBinary() ? TextIndex.LocalTermWeights.BINARY : null)
        .setExpression(new FieldRef(termParam.getName()));

    for (DocumentFeature.StopWordSet stopWords : docFeature.getStopWordSets()) {

        if (!stopWords.isEmpty()) {
            DocumentBuilder rowBuilder = DOMUtil.createDocumentBuilder();

            String stopWordRE;

            String separatorRE = docFeature.getWordSeparatorRE();
            switch (separatorRE) {
                case "\\s+":
                    stopWordRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWords) + ")\\p{Punct}*(\\s+|$)";
                    break;
                case "\\W+":
                    stopWordRE = "(\\W+)(" + JOINER.join(stopWords) + ")(\\W+)";
                    break;
                default:
                    throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + separatorRE + "\"");
            }

            InlineTable replacements = new InlineTable()
                .addRows(DOMUtil.createRow(rowBuilder, Arrays.asList("string", "stem", "regex"), Arrays.asList(stopWordRE, " ", "true")));

            TextIndexNormalization stopWordRemoval = new TextIndexNormalization()
                .setCaseSensitive(stopWords.isCaseSensitive())
                // Handles consecutive matches. See http://stackoverflow.com/a/25085385
                .setRecursive(Boolean.TRUE)
                .setInlineTable(replacements);

            termCounter.addTextIndexNormalizations(stopWordRemoval);
        }
    }

    // The sequence number is appended to the function name.
    String functionName = "tf" + "@" + String.valueOf(CountVectorizerModelConverter.SEQUENCE.getAndIncrement());

    DefineFunction termFrequency = new DefineFunction(functionName, OpType.CONTINUOUS, null)
        .setDataType(DataType.INTEGER)
        .addParameterFields(documentParam, termParam)
        .setExpression(termCounter);

    encoder.addDefineFunction(termFrequency);

    List<Feature> termFeatures = new ArrayList<>();

    for (String term : model.vocabulary()) {

        if (TermUtil.hasPunctuation(term)) {
            throw new IllegalArgumentException(term);
        }

        termFeatures.add(new TermFeature(encoder, termFrequency, docFeature, term));
    }

    return termFeatures;
}
Aggregations