Usage of `org.apache.ignite.ml.catboost.CatboostClassificationModel` in the Apache Ignite project.
Class `CatboostClassificationModelParserTest`, method `testParseAndPredict`.
/**
 * End-to-end test for {@code parse()} and {@code predict()} methods.
 *
 * <p>Locates the {@code TEST_MODEL_RESOURCE} classpath resource, builds a
 * {@link CatboostClassificationModel} from it via {@code mdlBuilder} and {@code parser}, scores a
 * single hand-written row of named feature values and compares the result with a fixed reference
 * prediction (absolute tolerance {@code 1e-5}).
 *
 * @throws IllegalStateException If the model resource cannot be found on the classpath or its URL
 *      cannot be converted to a file-system path.
 */
@Test
public void testParseAndPredict() {
    URL url = CatboostClassificationModelParserTest.class.getClassLoader().getResource(TEST_MODEL_RESOURCE);
    if (url == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");

    // URL.getPath() returns the raw, percent-encoded path (e.g. "%20" for a space, "/C:/..." on
    // Windows), which is not a valid file-system path. Going through the URI decodes it properly.
    String mdlPath;
    try {
        mdlPath = java.nio.file.Paths.get(url.toURI()).toString();
    }
    catch (java.net.URISyntaxException e) {
        throw new IllegalStateException("Invalid resource URL [resource_name=" + TEST_MODEL_RESOURCE
            + ", url=" + url + "]", e);
    }

    ModelReader reader = new FileSystemModelReader(mdlPath);

    // The model is closed automatically once the prediction has been checked.
    try (CatboostClassificationModel mdl = mdlBuilder.build(reader, parser)) {
        // Named feature values of the single row to score.
        HashMap<String, Double> input = new HashMap<>();
        input.put("ACTION", 1.0);
        input.put("RESOURCE", 39353.0);
        input.put("MGR_ID", 85475.0);
        input.put("ROLE_ROLLUP_1", 117961.0);
        input.put("ROLE_ROLLUP_2", 118300.0);
        input.put("ROLE_DEPTNAME", 123472.0);
        input.put("ROLE_TITLE", 117905.0);
        input.put("ROLE_FAMILY_DESC", 117906.0);
        input.put("ROLE_FAMILY", 290919.0);
        input.put("ROLE_CODE", 117908.0);

        double prediction = mdl.predict(VectorUtils.of(input));

        // Fixed reference value for this row; presumably produced by the original CatBoost
        // model outside Ignite — TODO confirm its provenance.
        assertEquals(0.9928904609329371, prediction, 1e-5);
    }
}
Aggregations