
Example 1 with XGModelComposition

Use of org.apache.ignite.ml.xgboost.XGModelComposition in project ignite by apache.

In class XGModelVisitor, method visitXgModel.

/**
 * {@inheritDoc}
 */
@Override
public XGModelComposition visitXgModel(XGBoostModelParser.XgModelContext ctx) {
    List<DecisionTreeNode> trees = new ArrayList<>();
    Set<String> featureNames = new HashSet<>();
    // First pass: collect the names of all features referenced by any tree.
    for (XGBoostModelParser.XgTreeContext treeCtx : ctx.xgTree())
        featureNames.addAll(treeDictionaryVisitor.visitXgTree(treeCtx));
    // Build a name-to-index dictionary from the collected feature names.
    Map<String, Integer> dict = buildDictionary(featureNames);
    // Second pass: convert each parsed tree into a DecisionTreeNode using that dictionary.
    XGTreeVisitor treeVisitor = new XGTreeVisitor(dict);
    for (XGBoostModelParser.XgTreeContext treeCtx : ctx.xgTree()) {
        DecisionTreeNode treeNode = treeVisitor.visitXgTree(treeCtx);
        trees.add(treeNode);
    }
    return new XGModelComposition(dict, trees);
}
Also used : XGModelComposition(org.apache.ignite.ml.xgboost.XGModelComposition) ArrayList(java.util.ArrayList) XGBoostModelParser(org.apache.ignite.ml.xgboost.parser.XGBoostModelParser) DecisionTreeNode(org.apache.ignite.ml.tree.DecisionTreeNode) HashSet(java.util.HashSet)
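
The two passes in visitXgModel (collect every feature name first, then build a name-to-index dictionary before converting the trees) can be illustrated with a minimal stand-alone sketch. The class FeatureDictionarySketch and its methods are hypothetical: Ignite's own buildDictionary() is not shown on this page, and sorting the names to obtain deterministic indices is an assumption of this sketch, not a statement about Ignite's behavior.

```java
import java.util.Arrays;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;

/**
 * Hypothetical sketch of the two-pass approach; NOT Ignite's buildDictionary().
 * Sorting names (TreeSet) is an assumption made only for deterministic indices.
 */
public class FeatureDictionarySketch {

    /** Pass 1 stand-in: gathers the distinct feature names used by all trees. */
    static Set<String> collectFeatureNames(List<List<String>> featuresPerTree) {
        Set<String> names = new TreeSet<>(); // lexicographic order
        for (List<String> treeFeatures : featuresPerTree)
            names.addAll(treeFeatures);
        return names;
    }

    /** Pass 2 stand-in: maps each feature name to a dense index 0..n-1. */
    static Map<String, Integer> buildDictionary(Set<String> featureNames) {
        Map<String, Integer> dict = new LinkedHashMap<>(); // keeps insertion order for printing
        int idx = 0;
        for (String name : featureNames)
            dict.put(name, idx++);
        return dict;
    }

    public static void main(String[] args) {
        // Two toy "trees", each described only by the feature names it splits on.
        List<List<String>> featuresPerTree = Arrays.asList(
            Arrays.asList("f29", "f56"),
            Arrays.asList("f56", "f109"));
        Map<String, Integer> dict = buildDictionary(collectFeatureNames(featuresPerTree));
        System.out.println(dict); // {f109=0, f29=1, f56=2}
    }
}
```

Because the names are compared as strings, f109 sorts before f29; a real implementation might instead parse the numeric suffix, which is a separate design choice.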

Example 2 with XGModelComposition

Use of org.apache.ignite.ml.xgboost.XGModelComposition in project ignite by apache.

In class XGBoostModelParserTest, method testParseAndPredict.

/**
 * End-to-end test for {@code parse()} and {@code predict()} methods.
 */
@Test
public void testParseAndPredict() {
    // Locate the serialized XGBoost model on the test classpath.
    URL url = XGBoostModelParserTest.class.getClassLoader().getResource(TEST_MODEL_RESOURCE);
    if (url == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");
    ModelReader reader = new FileSystemModelReader(url.getPath());
    // The model and both scanners are declared as resources, so they are closed automatically.
    try (XGModelComposition mdl = mdlBuilder.build(reader, parser);
        Scanner testDataScanner = new Scanner(XGBoostModelParserTest.class.getClassLoader().getResourceAsStream("datasets/agaricus-test-data.txt"));
        Scanner testExpResultsScanner = new Scanner(XGBoostModelParserTest.class.getClassLoader().getResourceAsStream("datasets/agaricus-test-expected-results.txt"))) {
        while (testDataScanner.hasNextLine()) {
            // Every test row must have a matching expected-result row.
            assertTrue(testExpResultsScanner.hasNextLine());
            String testDataStr = testDataScanner.nextLine();
            String testExpResultsStr = testExpResultsScanner.nextLine();
            // Turn "idx:value" tokens into named features "f<idx>"; tokens without ':' are skipped.
            HashMap<String, Double> testObj = new HashMap<>();
            for (String keyValueString : testDataStr.split(" ")) {
                String[] keyVal = keyValueString.split(":");
                if (keyVal.length == 2)
                    testObj.put("f" + keyVal[0], Double.parseDouble(keyVal[1]));
            }
            // Predict on the named-feature vector and compare with the reference output.
            double prediction = mdl.predict(VectorUtils.of(testObj));
            double expPrediction = Double.parseDouble(testExpResultsStr);
            assertEquals(expPrediction, prediction, 1e-6);
        }
    }
}
Also used : Scanner(java.util.Scanner) XGModelComposition(org.apache.ignite.ml.xgboost.XGModelComposition) HashMap(java.util.HashMap) FileSystemModelReader(org.apache.ignite.ml.inference.reader.FileSystemModelReader) URL(java.net.URL) ModelReader(org.apache.ignite.ml.inference.reader.ModelReader) Test(org.junit.Test)
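
The inner loop of the test converts each data line into a map of named features before calling predict. The sketch below isolates that step so it can be run on its own. It assumes each line is in LibSVM form (a label followed by idx:value pairs), which is consistent with the test skipping tokens that lack a ':' but is not stated on this page; the class LibSvmLineParserSketch is hypothetical. Unlike the test, it splits on any run of whitespace rather than on a single space.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Hypothetical stand-alone version of the test's per-line parsing.
 * Assumes LibSVM-style lines: "label idx:value idx:value ...".
 */
public class LibSvmLineParserSketch {

    /** Maps each "idx:value" token to the feature name "f" + idx. */
    static Map<String, Double> parseLine(String line) {
        Map<String, Double> features = new HashMap<>();
        for (String token : line.trim().split("\\s+")) {
            String[] keyVal = token.split(":");
            // Tokens without exactly one ':' (e.g. the leading label) are ignored, as in the test.
            if (keyVal.length == 2)
                features.put("f" + keyVal[0], Double.parseDouble(keyVal[1]));
        }
        return features;
    }

    public static void main(String[] args) {
        Map<String, Double> obj = parseLine("1 3:1 10:1 11:1");
        System.out.println(obj.get("f3") + " " + obj.get("f10") + " " + obj.size()); // 1.0 1.0 3
    }
}
```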

Aggregations

XGModelComposition (org.apache.ignite.ml.xgboost.XGModelComposition): 2
URL (java.net.URL): 1
ArrayList (java.util.ArrayList): 1
HashMap (java.util.HashMap): 1
HashSet (java.util.HashSet): 1
Scanner (java.util.Scanner): 1
FileSystemModelReader (org.apache.ignite.ml.inference.reader.FileSystemModelReader): 1
ModelReader (org.apache.ignite.ml.inference.reader.ModelReader): 1
DecisionTreeNode (org.apache.ignite.ml.tree.DecisionTreeNode): 1
XGBoostModelParser (org.apache.ignite.ml.xgboost.parser.XGBoostModelParser): 1
Test (org.junit.Test): 1