Search in sources :

Example 6 with FieldRef

use of org.dmg.pmml.FieldRef in project jpmml-r by jpmml.

the class SVMConverter method encodeModel.

/**
 * Encodes the wrapped R object as a PMML {@link SupportVectorMachineModel}.
 *
 * <p>The {@code type} element selects the encoding path: the two classification types go
 * through {@code encodeClassification}, all other supported types through
 * {@code encodeRegression}. The kernel, built from the {@code kernel}, {@code degree},
 * {@code gamma} and {@code coef0} elements, is attached last.
 *
 * @param schema the schema handed to the encoders; for regression models with a non-empty
 *     {@code y.scale} its label is cast to {@code ContinuousLabel}
 * @return the encoded model with its kernel set
 * @throws IllegalArgumentException if a one-classification model carries a non-empty
 *     {@code y.scale}, or if the SVM type is not handled by this method
 */
@Override
public SupportVectorMachineModel encodeModel(Schema schema) {
    RGenericVector svm = getObject();
    RDoubleVector type = (RDoubleVector) svm.getValue("type");
    RDoubleVector kernel = (RDoubleVector) svm.getValue("kernel");
    RDoubleVector degree = (RDoubleVector) svm.getValue("degree");
    RDoubleVector gamma = (RDoubleVector) svm.getValue("gamma");
    RDoubleVector coef0 = (RDoubleVector) svm.getValue("coef0");
    RGenericVector yScale = (RGenericVector) svm.getValue("y.scale");
    RIntegerVector nSv = (RIntegerVector) svm.getValue("nSV");
    RDoubleVector sv = (RDoubleVector) svm.getValue("SV");
    RDoubleVector rho = (RDoubleVector) svm.getValue("rho");
    RDoubleVector coefs = (RDoubleVector) svm.getValue("coefs");
    // The numeric "type" and "kernel" codes are used as indexes into the enum constants.
    Type svmType = Type.values()[ValueUtil.asInt(type.asScalar())];
    Kernel svmKernel = Kernel.values()[ValueUtil.asInt(kernel.asScalar())];
    SupportVectorMachineModel supportVectorMachineModel;
    switch (svmType) {
        case C_CLASSIFICATION:
        case NU_CLASSIFICATION:
            {
                supportVectorMachineModel = encodeClassification(sv, nSv, rho, coefs, schema);
            }
            break;
        case ONE_CLASSIFICATION:
            {
                // Output transformation that evaluates "decision value <= 0".
                Transformation outlier = new OutlierTransformation() {

                    @Override
                    public Expression createExpression(FieldRef fieldRef) {
                        return PMMLUtil.createApply("lessOrEqual", fieldRef, PMMLUtil.createConstant(0d));
                    }
                };
                supportVectorMachineModel = encodeRegression(sv, rho, coefs, schema).setOutput(ModelUtil.createPredictedOutput(FieldName.create("decisionFunction"), OpType.CONTINUOUS, DataType.DOUBLE, outlier));
                // This path applies no target rescaling, so a non-empty y.scale is rejected.
                if (yScale != null && yScale.size() > 0) {
                    throw new IllegalArgumentException("One-classification SVM must not define a non-empty y.scale element");
                }
            }
            break;
        case EPS_REGRESSION:
        case NU_REGRESSION:
            {
                supportVectorMachineModel = encodeRegression(sv, rho, coefs, schema);
                if (yScale != null && yScale.size() > 0) {
                    RDoubleVector yScaledCenter = (RDoubleVector) yScale.getValue("scaled:center");
                    RDoubleVector yScaledScale = (RDoubleVector) yScale.getValue("scaled:scale");
                    // NOTE(review): the scale factor is negated here, presumably to match the sign
                    // convention of the encoded regression output; confirm against encodeRegression.
                    supportVectorMachineModel.setTargets(ModelUtil.createRescaleTargets(-1d * yScaledScale.asScalar(), yScaledCenter.asScalar(), (ContinuousLabel) schema.getLabel()));
                }
            }
            break;
        default:
            throw new IllegalArgumentException("Unsupported SVM type: " + svmType);
    }
    supportVectorMachineModel.setKernel(svmKernel.createKernel(degree.asScalar(), gamma.asScalar(), coef0.asScalar()));
    return supportVectorMachineModel;
}
Also used : OpType(org.dmg.pmml.OpType) DataType(org.dmg.pmml.DataType) Transformation(org.jpmml.converter.Transformation) OutlierTransformation(org.jpmml.converter.OutlierTransformation) FieldRef(org.dmg.pmml.FieldRef) OutlierTransformation(org.jpmml.converter.OutlierTransformation) Expression(org.dmg.pmml.Expression) SupportVectorMachineModel(org.dmg.pmml.support_vector_machine.SupportVectorMachineModel) RadialBasisKernel(org.dmg.pmml.support_vector_machine.RadialBasisKernel) PolynomialKernel(org.dmg.pmml.support_vector_machine.PolynomialKernel) LinearKernel(org.dmg.pmml.support_vector_machine.LinearKernel) SigmoidKernel(org.dmg.pmml.support_vector_machine.SigmoidKernel)

Example 7 with FieldRef

use of org.dmg.pmml.FieldRef in project jpmml-r by jpmml.

the class ExpressionCompactorTest method compactLogicalExpression.

/**
 * Checks that compacting an "or" whose middle operand is itself a nested "or" tree yields
 * a single Apply whose operands are all the leaf "equal" comparisons, in left-to-right order.
 */
@Test
public void compactLogicalExpression() {
    FieldRef x = new FieldRef(FieldName.create("x"));

    // Leaf comparisons, named after their position in the input tree.
    Apply leaf1 = createApply("equal", x, createConstant("1"));
    Apply leaf2LL = createApply("equal", x, createConstant("2/L/L"));
    Apply leaf2LR = createApply("equal", x, createConstant("2/L/R"));
    Apply nestedLeftOr = createApply("or", leaf2LL, leaf2LR);
    Apply leaf2R = createApply("equal", x, createConstant("2/R"));
    Apply nestedOr = createApply("or", nestedLeftOr, leaf2R);
    Apply leaf3 = createApply("equal", x, createConstant("3"));

    Apply input = createApply("or", leaf1, nestedOr, leaf3);
    Apply compacted = compact(input);

    assertEquals(
        Arrays.asList(leaf1, leaf2LL, leaf2LR, leaf2R, leaf3),
        compacted.getExpressions());
}
Also used : FieldRef(org.dmg.pmml.FieldRef) Apply(org.dmg.pmml.Apply) Test(org.junit.Test)

Example 8 with FieldRef

use of org.dmg.pmml.FieldRef in project jpmml-r by jpmml.

the class ExpressionTranslatorTest method translateArithmeticExpressionChain.

/**
 * Checks that "A + B - X + C" is translated into a left-nested Apply tree:
 * {@code ((A + B) - X) + C}.
 */
@Test
public void translateArithmeticExpressionChain() {
    Apply root = (Apply) ExpressionTranslator.translateExpression("A + B - X + C");

    // Outermost level: (...) + C
    List<Expression> rootArgs = checkApply(root, "+", Apply.class, FieldRef.class);
    Expression minusNode = rootArgs.get(0);
    Expression refC = rootArgs.get(1);

    // Middle level: (...) - X
    List<Expression> minusArgs = checkApply((Apply) minusNode, "-", Apply.class, FieldRef.class);
    checkFieldRef((FieldRef) refC, FieldName.create("C"));
    Expression plusNode = minusArgs.get(0);
    Expression refX = minusArgs.get(1);

    // Innermost level: A + B
    List<Expression> plusArgs = checkApply((Apply) plusNode, "+", FieldRef.class, FieldRef.class);
    checkFieldRef((FieldRef) refX, FieldName.create("X"));
    Expression refA = plusArgs.get(0);
    Expression refB = plusArgs.get(1);
    checkFieldRef((FieldRef) refA, FieldName.create("A"));
    checkFieldRef((FieldRef) refB, FieldName.create("B"));
}
Also used : FieldRef(org.dmg.pmml.FieldRef) Expression(org.dmg.pmml.Expression) Apply(org.dmg.pmml.Apply) Test(org.junit.Test)

Example 9 with FieldRef

use of org.dmg.pmml.FieldRef in project jpmml-r by jpmml.

the class PreProcessEncoder method encodeExpression.

/**
 * Builds the pre-processing expression for a feature by wrapping its reference, in this
 * order, with whichever steps have a value registered under the feature's name:
 * range scaling {@code (x - min) / (max - min)}, mean subtraction, division by the standard
 * deviation, and replacement of a missing input by the median.
 *
 * @param feature the feature whose name is used to look up the statistics
 * @return the composed expression, or {@code null} if the final expression is a bare
 *     {@link FieldRef}
 */
private Expression encodeExpression(Feature feature) {
    FieldName name = feature.getName();
    Expression result = feature.ref();

    List<Double> range = this.ranges.get(name);
    if (range != null) {
        Double min = range.get(0);
        Double max = range.get(1);
        Expression shifted = PMMLUtil.createApply("-", result, PMMLUtil.createConstant(min));
        result = PMMLUtil.createApply("/", shifted, PMMLUtil.createConstant(max - min));
    }

    Double center = this.mean.get(name);
    if (center != null) {
        result = PMMLUtil.createApply("-", result, PMMLUtil.createConstant(center));
    }

    Double scale = this.std.get(name);
    if (scale != null) {
        result = PMMLUtil.createApply("/", result, PMMLUtil.createConstant(scale));
    }

    Double fallback = this.median.get(name);
    if (fallback != null) {
        // The missing-value test is applied to the raw field, not to the transformed expression.
        Expression isPresent = PMMLUtil.createApply("isNotMissing", new FieldRef(name));
        result = PMMLUtil.createApply("if", isPresent, result, PMMLUtil.createConstant(fallback));
    }

    return (result instanceof FieldRef) ? null : result;
}
Also used : FieldRef(org.dmg.pmml.FieldRef) Expression(org.dmg.pmml.Expression) FieldName(org.dmg.pmml.FieldName)

Example 10 with FieldRef

use of org.dmg.pmml.FieldRef in project jpmml-sparkml by jpmml.

the class CountVectorizerModelConverter method encodeFeatures.

/**
 * Encodes the transformer's vocabulary as term-frequency features.
 *
 * <p>A {@link TextIndex} over the "document" and "term" parameters is wrapped in a
 * {@link DefineFunction} named {@code tf@<sequence number>} and registered with the encoder.
 * Every non-empty stop word set of the input document feature contributes one
 * {@link TextIndexNormalization} to the text index. One {@link TermFeature} is returned per
 * vocabulary term.
 *
 * @param encoder the encoder that supplies the input feature and receives the function
 * @return one feature per vocabulary term, in vocabulary order
 * @throws IllegalArgumentException if the word separator regex is neither {@code \s+} nor
 *     {@code \W+} while a non-empty stop word set is present, or if
 *     {@code TermUtil.hasPunctuation} reports true for a vocabulary term
 */
@Override
public List<Feature> encodeFeatures(SparkMLEncoder encoder) {
    CountVectorizerModel model = getTransformer();
    DocumentFeature docFeature = (DocumentFeature) encoder.getOnlyFeature(model.getInputCol());

    ParameterField documentParam = new ParameterField(FieldName.create("document"));
    ParameterField termParam = new ParameterField(FieldName.create("term"));

    TextIndex termIndex = new TextIndex(documentParam.getName())
        .setTokenize(Boolean.TRUE)
        .setWordSeparatorCharacterRE(docFeature.getWordSeparatorRE())
        .setLocalTermWeights(model.getBinary() ? TextIndex.LocalTermWeights.BINARY : null)
        .setExpression(new FieldRef(termParam.getName()));

    for (DocumentFeature.StopWordSet stopWords : docFeature.getStopWordSets()) {
        if (stopWords.isEmpty()) {
            continue;
        }

        DocumentBuilder rowBuilder = DOMUtil.createDocumentBuilder();

        String separatorRE = docFeature.getWordSeparatorRE();
        String stopWordRE;
        switch (separatorRE) {
            case "\\s+":
                stopWordRE = "(^|\\s+)\\p{Punct}*(" + JOINER.join(stopWords) + ")\\p{Punct}*(\\s+|$)";
                break;
            case "\\W+":
                stopWordRE = "(\\W+)(" + JOINER.join(stopWords) + ")(\\W+)";
                break;
            default:
                throw new IllegalArgumentException("Expected \"\\s+\" or \"\\W+\" as splitter regex pattern, got \"" + separatorRE + "\"");
        }

        // Single-row table: the stop word pattern is a regex and maps to a single space.
        InlineTable replacementTable = new InlineTable()
            .addRows(DOMUtil.createRow(rowBuilder, Arrays.asList("string", "stem", "regex"), Arrays.asList(stopWordRE, " ", "true")));

        // Recursive mode handles consecutive matches. See http://stackoverflow.com/a/25085385
        TextIndexNormalization stopWordNormalization = new TextIndexNormalization()
            .setCaseSensitive(stopWords.isCaseSensitive())
            .setRecursive(Boolean.TRUE)
            .setInlineTable(replacementTable);

        termIndex.addTextIndexNormalizations(stopWordNormalization);
    }

    String functionName = "tf@" + String.valueOf(CountVectorizerModelConverter.SEQUENCE.getAndIncrement());
    DefineFunction termFrequency = new DefineFunction(functionName, OpType.CONTINUOUS, null)
        .setDataType(DataType.INTEGER)
        .addParameterFields(documentParam, termParam)
        .setExpression(termIndex);
    encoder.addDefineFunction(termFrequency);

    List<Feature> features = new ArrayList<>();
    for (String term : model.vocabulary()) {
        if (TermUtil.hasPunctuation(term)) {
            throw new IllegalArgumentException(term);
        }
        features.add(new TermFeature(encoder, termFrequency, docFeature, term));
    }
    return features;
}
Also used : InlineTable(org.dmg.pmml.InlineTable) FieldRef(org.dmg.pmml.FieldRef) TextIndex(org.dmg.pmml.TextIndex) DocumentFeature(org.jpmml.sparkml.DocumentFeature) ArrayList(java.util.ArrayList) Feature(org.jpmml.converter.Feature) DocumentFeature(org.jpmml.sparkml.DocumentFeature) TermFeature(org.jpmml.sparkml.TermFeature) TermFeature(org.jpmml.sparkml.TermFeature) TextIndexNormalization(org.dmg.pmml.TextIndexNormalization) CountVectorizerModel(org.apache.spark.ml.feature.CountVectorizerModel) DocumentBuilder(javax.xml.parsers.DocumentBuilder) DefineFunction(org.dmg.pmml.DefineFunction) ParameterField(org.dmg.pmml.ParameterField)

Aggregations

FieldRef (org.dmg.pmml.FieldRef)10 Apply (org.dmg.pmml.Apply)6 Expression (org.dmg.pmml.Expression)5 Test (org.junit.Test)4 ArrayList (java.util.ArrayList)3 Constant (org.dmg.pmml.Constant)3 DefineFunction (org.dmg.pmml.DefineFunction)2 FieldName (org.dmg.pmml.FieldName)2 ParameterField (org.dmg.pmml.ParameterField)2 Feature (org.jpmml.converter.Feature)2 Transformation (org.jpmml.converter.Transformation)2 DocumentBuilder (javax.xml.parsers.DocumentBuilder)1 CountVectorizerModel (org.apache.spark.ml.feature.CountVectorizerModel)1 DataType (org.dmg.pmml.DataType)1 DerivedField (org.dmg.pmml.DerivedField)1 InlineTable (org.dmg.pmml.InlineTable)1 OpType (org.dmg.pmml.OpType)1 TextIndex (org.dmg.pmml.TextIndex)1 TextIndexNormalization (org.dmg.pmml.TextIndexNormalization)1 MiningModel (org.dmg.pmml.mining.MiningModel)1