Search in sources:

Example 1 with Array

Use of org.dmg.pmml.Array in project shifu by ShifuML.

The class TreeNodePmmlElementCreator, method convert.

/**
 * Converts a Shifu decision-tree {@link Node} (and, recursively, its children) into a PMML tree node.
 *
 * <p>The predicate attached to the returned node encodes the branch of {@code split} that leads into
 * {@code node}:
 * <ul>
 * <li>numerical column: {@code lessThan} the threshold on the left branch, {@code greaterOrEqual} on the right;</li>
 * <li>categorical column: a string set predicate over the split's categories, {@code isIn} when
 * {@code isLeft == split.isLeft()} and {@code isNotIn} otherwise;</li>
 * <li>any other column type: no predicate is set ({@code null}).</li>
 * </ul>
 *
 * @param node   the Shifu tree node to convert
 * @param isLeft whether {@code node} is reached through the left branch of {@code split}
 * @param split  the split whose branch leads to {@code node}; its column selects the predicate field
 * @return the PMML node; children are attached unless {@code node} has no split or is a real leaf
 */
public org.dmg.pmml.tree.Node convert(Node node, boolean isLeft, Split split) {
    org.dmg.pmml.tree.Node pmmlNode = new org.dmg.pmml.tree.Node();
    pmmlNode.setId(String.valueOf(node.getId()));
    if (node.getPredict() != null) {
        pmmlNode.setScore(String.valueOf(treeModel.isClassification() ? node.getPredict().getClassValue() : node.getPredict().getPredict()));
    }
    pmmlNode.setDefaultChild(null);
    Predicate predicate = null;
    ColumnConfig columnConfig = this.columnConfigList.get(split.getColumnNum());
    if (columnConfig.isNumerical()) {
        SimplePredicate p = new SimplePredicate();
        p.setValue(String.valueOf(split.getThreshold()));
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        p.setOperator(SimplePredicate.Operator.fromValue(isLeft ? "lessThan" : "greaterOrEqual"));
        predicate = p;
    } else if (columnConfig.isCategorical()) {
        SimpleSetPredicate p = new SimpleSetPredicate();
        Set<Short> childCategories = split.getLeftOrRightCategories();
        // TODO, how to support segment variable in tree model, here should be changed
        p.setField(new FieldName(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
        List<String> valueList = treeModel.getCategoricalColumnNameNames().get(columnConfig.getColumnNum());
        StringBuilder arrayStr = new StringBuilder();
        for (Short sh : childCategories) {
            if (arrayStr.length() > 0) {
                arrayStr.append(' ');
            }
            // An index past the known value list is still written, as an explicit empty string token,
            // so the element is not silently dropped from the array.
            String value = sh >= valueList.size() ? "" : valueList.get(sh);
            arrayStr.append(toPmmlArrayToken(value));
        }
        Array array = new Array(Array.Type.fromValue("string"), arrayStr.toString());
        p.setArray(array);
        // isIn when this node's side matches split.isLeft(), isNotIn otherwise.
        p.setBooleanOperator(SimpleSetPredicate.BooleanOperator.fromValue(isLeft == split.isLeft() ? "isIn" : "isNotIn"));
        predicate = p;
    }
    pmmlNode.setPredicate(predicate);
    if (node.getSplit() == null || node.isRealLeaf()) {
        return pmmlNode;
    }
    List<org.dmg.pmml.tree.Node> childList = pmmlNode.getNodes();
    org.dmg.pmml.tree.Node left = convert(node.getLeft(), true, node.getSplit());
    org.dmg.pmml.tree.Node right = convert(node.getRight(), false, node.getSplit());
    childList.add(left);
    childList.add(right);
    return pmmlNode;
}

/**
 * Formats one value as an element of a whitespace-separated PMML {@code Array} of type {@code string}.
 *
 * <p>A value is wrapped in double quotes when it is empty, contains any whitespace, or contains a double
 * quote. Embedded double quotes are escaped as {@code \"}. The {@code \"} escape is only meaningful inside
 * a quoted element, so an unquoted token must not carry it.
 *
 * @param value the raw category value (must not be {@code null})
 * @return the token to place in the array string
 */
private static String toPmmlArrayToken(String value) {
    boolean needsQuotes = value.isEmpty() || value.indexOf('"') >= 0;
    for (int i = 0; !needsQuotes && i < value.length(); i++) {
        needsQuotes = Character.isWhitespace(value.charAt(i));
    }
    if (!needsQuotes) {
        return value;
    }
    return "\"" + value.replace("\"", "\\\"") + "\"";
}
Also used : Set(java.util.Set) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) Node(ml.shifu.shifu.core.dtrain.dt.Node) SimplePredicate(org.dmg.pmml.SimplePredicate) Predicate(org.dmg.pmml.Predicate) SimplePredicate(org.dmg.pmml.SimplePredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate) Array(org.dmg.pmml.Array) List(java.util.List) FieldName(org.dmg.pmml.FieldName)

Example 2 with Array

Use of org.dmg.pmml.Array in project shifu by ShifuML.

The class ModelStatsCreator, method build.

/**
 * Builds the PMML {@link ModelStats} section with one {@link UnivariateStats} entry per final-selected column.
 *
 * <p>When {@code basicML} is a {@link BasicFloatNetwork} whose feature set is non-empty, only final-selected
 * columns whose column number is in that feature set are included.
 *
 * @param basicML the trained model
 * @return the populated model statistics
 */
@Override
public ModelStats build(BasicML basicML) {
    ModelStats modelStats = new ModelStats();
    // Only a BasicFloatNetwork carries a feature subset. Leaving it null means "no restriction":
    // CollectionUtils.isEmpty(null) is true, the same as for an empty feature set.
    Set<Integer> featureSet = null;
    if (basicML instanceof BasicFloatNetwork) {
        featureSet = ((BasicFloatNetwork) basicML).getFeatureSet();
    }
    for (ColumnConfig columnConfig : columnConfigList) {
        if (columnConfig.isFinalSelect() && (CollectionUtils.isEmpty(featureSet) || featureSet.contains(columnConfig.getColumnNum()))) {
            modelStats.addUnivariateStats(createColumnUnivariateStats(columnConfig));
        }
    }
    return modelStats;
}

/**
 * Creates the univariate statistics for one column.
 *
 * <p>Categorical columns get discrete stats (count array, plus extensions unless concise). Other columns
 * get numeric info, plus continuous stats unless concise.
 *
 * @param columnConfig the column to describe
 * @return the statistics entry for the column
 */
private UnivariateStats createColumnUnivariateStats(ColumnConfig columnConfig) {
    UnivariateStats univariateStats = new UnivariateStats();
    // here, no need to consider if column is in segment expansion
    // as we need to address new stats variable
    // set simple column name in PMML
    univariateStats.setField(FieldName.create(NormalUtils.getSimpleColumnName(columnConfig.getColumnName())));
    if (columnConfig.isCategorical()) {
        DiscrStats discrStats = new DiscrStats();
        Array countArray = createCountArray(columnConfig);
        discrStats.addArrays(countArray);
        if (!isConcise) {
            List<Extension> extensions = createExtensions(columnConfig);
            discrStats.addExtensions(extensions.toArray(new Extension[extensions.size()]));
        }
        univariateStats.setDiscrStats(discrStats);
    } else {
        // numerical column
        univariateStats.setNumericInfo(createNumericInfo(columnConfig));
        if (!isConcise) {
            univariateStats.setContStats(createConStats(columnConfig));
        }
    }
    return univariateStats;
}
Also used : Array(org.dmg.pmml.Array) Extension(org.dmg.pmml.Extension) DiscrStats(org.dmg.pmml.DiscrStats) ColumnConfig(ml.shifu.shifu.container.obj.ColumnConfig) UnivariateStats(org.dmg.pmml.UnivariateStats) ModelStats(org.dmg.pmml.ModelStats) BasicFloatNetwork(ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)

Example 3 with Array

Use of org.dmg.pmml.Array in project drools by kiegroup.

The class ModelUtilsTest, method getObjectsFromArray.

/**
 * Verifies that ModelUtils.getObjectsFromArray parses each supported Array type into the matching Java type.
 * INT becomes Integer, STRING becomes String, and REAL becomes Double; order and count are preserved.
 */
@Test
public void getObjectsFromArray() {
    // INT -> Integer
    List<String> intValues = Arrays.asList("32", "11", "43");
    List<Object> intObjects = ModelUtils.getObjectsFromArray(getArray(Array.Type.INT, intValues));
    assertEquals(intValues.size(), intObjects.size());
    int position = 0;
    for (String raw : intValues) {
        Object parsed = intObjects.get(position++);
        assertTrue(parsed instanceof Integer);
        assertEquals(Integer.valueOf(raw), parsed);
    }
    // STRING -> String (numeric-looking text stays textual)
    List<String> stringValues = Arrays.asList("just", "11", "fun");
    List<Object> stringObjects = ModelUtils.getObjectsFromArray(getArray(Array.Type.STRING, stringValues));
    assertEquals(stringValues.size(), stringObjects.size());
    position = 0;
    for (String raw : stringValues) {
        Object parsed = stringObjects.get(position++);
        assertTrue(parsed instanceof String);
        assertEquals(raw, parsed);
    }
    // REAL -> Double (integral text is still parsed as Double)
    List<String> realValues = Arrays.asList("23.11", "11", "123.123");
    List<Object> realObjects = ModelUtils.getObjectsFromArray(getArray(Array.Type.REAL, realValues));
    assertEquals(realValues.size(), realObjects.size());
    position = 0;
    for (String raw : realValues) {
        Object parsed = realObjects.get(position++);
        assertTrue(parsed instanceof Double);
        assertEquals(Double.valueOf(raw), parsed);
    }
}
Also used : PMMLModelTestUtils.getArray(org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getArray) Array(org.dmg.pmml.Array) Test(org.junit.Test)

Example 4 with Array

Use of org.dmg.pmml.Array in project drools by kiegroup.

The class PMMLModelTestUtils, method getArray.

/**
 * Creates a PMML Array of the given type whose content is {@code values} joined by single spaces.
 * The array's {@code n} attribute is set to the number of values. Values are not quoted.
 *
 * @param arrayType the PMML array type
 * @param values    the element values, in order
 * @return the populated Array
 */
public static Array getArray(Array.Type arrayType, final List<String> values) {
    final Array result = new Array(arrayType, String.join(" ", values));
    result.setN(values.size());
    return result;
}
Also used : Array(org.dmg.pmml.Array)

Example 5 with Array

Use of org.dmg.pmml.Array in project drools by kiegroup.

The class PMMLModelTestUtils, method getSimpleSetPredicate.

/**
 * Creates a SimpleSetPredicate on field {@code predicateName} with the given boolean operator.
 * Its array is built from {@code values} via getArray.
 *
 * @param predicateName   name of the field the predicate tests
 * @param arrayType       PMML type of the value array
 * @param values          the set members, in order
 * @param booleanOperator isIn / isNotIn operator to apply
 * @return the populated predicate
 */
public static SimpleSetPredicate getSimpleSetPredicate(final String predicateName, final Array.Type arrayType, final List<String> values, final SimpleSetPredicate.BooleanOperator booleanOperator) {
    final SimpleSetPredicate predicate = new SimpleSetPredicate();
    predicate.setField(FieldName.create(predicateName));
    predicate.setBooleanOperator(booleanOperator);
    predicate.setArray(getArray(arrayType, values));
    return predicate;
}
Also used : Array(org.dmg.pmml.Array) FieldName(org.dmg.pmml.FieldName) SimpleSetPredicate(org.dmg.pmml.SimpleSetPredicate)

Aggregations

Array (org.dmg.pmml.Array)10 SimpleSetPredicate (org.dmg.pmml.SimpleSetPredicate)5 FieldName (org.dmg.pmml.FieldName)3 ColumnConfig (ml.shifu.shifu.container.obj.ColumnConfig)2 PMMLModelTestUtils.getArray (org.kie.pmml.compiler.api.testutils.PMMLModelTestUtils.getArray)2 ArrayList (java.util.ArrayList)1 List (java.util.List)1 Set (java.util.Set)1 BasicFloatNetwork (ml.shifu.shifu.core.dtrain.dataset.BasicFloatNetwork)1 Node (ml.shifu.shifu.core.dtrain.dt.Node)1 DiscrStats (org.dmg.pmml.DiscrStats)1 Extension (org.dmg.pmml.Extension)1 ModelStats (org.dmg.pmml.ModelStats)1 Predicate (org.dmg.pmml.Predicate)1 SimplePredicate (org.dmg.pmml.SimplePredicate)1 UnivariateStats (org.dmg.pmml.UnivariateStats)1 Test (org.junit.Test)1 KiePMMLException (org.kie.pmml.api.exceptions.KiePMMLException)1 KiePMMLSimpleSetPredicate (org.kie.pmml.commons.model.predicates.KiePMMLSimpleSetPredicate)1