Use of org.apache.ignite.ml.xgboost.XGModelComposition in the Apache Ignite project.
Class XGModelVisitor, method visitXgModel.
/**
 * {@inheritDoc}
 */
@Override
public XGModelComposition visitXgModel(XGBoostModelParser.XgModelContext ctx) {
    // First pass: accumulate whatever names the dictionary visitor reports for each tree.
    Set<String> allFeatures = new HashSet<>();
    for (XGBoostModelParser.XgTreeContext tree : ctx.xgTree()) {
        allFeatures.addAll(treeDictionaryVisitor.visitXgTree(tree));
    }

    // Turn the collected names into a name -> integer dictionary (details live in buildDictionary).
    Map<String, Integer> featureDict = buildDictionary(allFeatures);

    // Second pass: convert every tree using a visitor that was handed the dictionary.
    XGTreeVisitor nodeBuilder = new XGTreeVisitor(featureDict);
    List<DecisionTreeNode> roots = new ArrayList<>();
    for (XGBoostModelParser.XgTreeContext tree : ctx.xgTree()) {
        roots.add(nodeBuilder.visitXgTree(tree));
    }

    return new XGModelComposition(featureDict, roots);
}
Use of org.apache.ignite.ml.xgboost.XGModelComposition in the Apache Ignite project.
Class XGBoostModelParserTest, method testParseAndPredict.
/**
 * End-to-end test for {@code parse()} and {@code predict()} methods.
 * <p>
 * Builds a model from {@code TEST_MODEL_RESOURCE}, then reads the test data set line by line. Each line is split
 * on spaces into {@code index:value} pairs (tokens that do not split into exactly two parts are skipped) and
 * stored in a map under the key {@code "f" + index}. The prediction for that map must match the corresponding
 * line of the expected-results file within {@code 1e-6}.
 *
 * @throws IllegalStateException If the model or one of the data set resources is not on the class path, or the
 *      model resource URL cannot be converted to a URI.
 */
@Test
public void testParseAndPredict() {
    URL url = XGBoostModelParserTest.class.getClassLoader().getResource(TEST_MODEL_RESOURCE);
    if (url == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");

    String mdlPath;
    try {
        // URL.getPath() keeps percent-escapes (e.g. "%20" for a space), which is not a valid file name.
        // Going through the URI decodes them and yields a real file system path.
        mdlPath = java.nio.file.Paths.get(url.toURI()).toString();
    }
    catch (java.net.URISyntaxException e) {
        throw new IllegalStateException("Malformed resource URL [url=" + url + "]", e);
    }

    ModelReader reader = new FileSystemModelReader(mdlPath);

    // Charset is pinned so that parsing does not depend on the platform default.
    try (XGModelComposition mdl = mdlBuilder.build(reader, parser);
        Scanner testDataScanner = new Scanner(openTestResource("datasets/agaricus-test-data.txt"), "UTF-8");
        Scanner testExpResultsScanner = new Scanner(openTestResource("datasets/agaricus-test-expected-results.txt"), "UTF-8")) {
        while (testDataScanner.hasNextLine()) {
            // Every data line must have a matching expected result.
            assertTrue(testExpResultsScanner.hasNextLine());

            String testDataStr = testDataScanner.nextLine();
            String testExpResultsStr = testExpResultsScanner.nextLine();

            HashMap<String, Double> testObj = new HashMap<>();
            for (String keyValueString : testDataStr.split(" ")) {
                String[] keyVal = keyValueString.split(":");
                // Only well-formed "index:value" tokens contribute a feature.
                if (keyVal.length == 2)
                    testObj.put("f" + keyVal[0], Double.parseDouble(keyVal[1]));
            }

            double prediction = mdl.predict(VectorUtils.of(testObj));
            double expPrediction = Double.parseDouble(testExpResultsStr);

            assertEquals(expPrediction, prediction, 1e-6);
        }
    }
}

/**
 * Opens a class path resource as a stream, failing with a descriptive error instead of a later NPE.
 *
 * @param rsrcName Resource name relative to the class path root.
 * @return Open input stream for the resource.
 * @throws IllegalStateException If the resource cannot be found.
 */
private static java.io.InputStream openTestResource(String rsrcName) {
    java.io.InputStream in = XGBoostModelParserTest.class.getClassLoader().getResourceAsStream(rsrcName);
    if (in == null)
        throw new IllegalStateException("File not found [resource_name=" + rsrcName + "]");

    return in;
}
Aggregations