Example 1 with CatboostRegressionModel

Use of org.apache.ignite.ml.catboost.CatboostRegressionModel in the Apache Ignite project, from the method testParseAndPredict of the class CatboostRegressionModelParserTest.

/**
 * End-to-end test for {@code parse()} and {@code predict()} methods.
 */
@Test
public void testParseAndPredict() {
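    // Locate the serialized CatBoost model on the test classpath.
    // TEST_MODEL_RESOURCE, mdlBuilder and parser are members of the test class that this excerpt does not show.
    URL url = CatboostRegressionModelParserTest.class.getClassLoader().getResource(TEST_MODEL_RESOURCE);
    if (url == null)
        throw new IllegalStateException("File not found [resource_name=" + TEST_MODEL_RESOURCE + "]");
    ModelReader reader = new FileSystemModelReader(url.getPath());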
    try (CatboostRegressionModel mdl = mdlBuilder.build(reader, parser)) {
        HashMap<String, Double> input = new HashMap<>();
        input.put("f_0", 0.02731d);
        input.put("f_1", 0.0d);
        input.put("f_2", 7.07d);
        input.put("f_3", 0d);
        input.put("f_4", 0.469d);
        input.put("f_5", 6.421d);
        input.put("f_6", 78.9d);
        input.put("f_7", 4.9671d);
        input.put("f_8", 2d);
        input.put("f_9", 242.0d);
        input.put("f_10", 17.8d);
        input.put("f_11", 396.9d);
        input.put("f_12", 9.14d);
        // Predict from the named feature values and compare with the expected value (tolerance 1e-5).
        double prediction = mdl.predict(VectorUtils.of(input));
        assertEquals(21.164552741740483, prediction, 1e-5);
    }
}
Also used: CatboostRegressionModel (org.apache.ignite.ml.catboost.CatboostRegressionModel), FileSystemModelReader (org.apache.ignite.ml.inference.reader.FileSystemModelReader), ModelReader (org.apache.ignite.ml.inference.reader.ModelReader), HashMap (java.util.HashMap), URL (java.net.URL), Test (org.junit.Test)
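
A note on the resource lookup in the test above: `URL.getPath()` returns the path in its percent-encoded form, so a checkout directory containing a space yields `%20` in the string handed to `FileSystemModelReader`, which may then fail to find the file. The sketch below shows one possible way to decode the location using only the JDK. The class `ResourcePaths` and the method `toFilePath` are hypothetical names introduced for illustration and are not part of Ignite, and the approach assumes a `file:` URL (a resource packed inside a JAR would need different handling).

```java
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.file.Paths;

/** Hypothetical helper, not part of Ignite: resolves a file: URL to a decoded filesystem path. */
public class ResourcePaths {
    /**
     * Converts a file: URL to a filesystem path string.
     * Going through URI and Path decodes percent-escapes such as %20.
     */
    public static String toFilePath(URL url) {
        try {
            return Paths.get(url.toURI()).toString();
        }
        catch (URISyntaxException e) {
            throw new IllegalStateException("Malformed resource URL [url=" + url + "]", e);
        }
    }

    public static void main(String[] args) throws Exception {
        // Build a file: URL for a path whose directory name contains a space.
        // The file does not need to exist for the conversion itself.
        URL url = Paths.get("models dir", "model.cbm").toAbsolutePath().toUri().toURL();

        System.out.println(url.getPath());    // encoded form, contains "models%20dir"
        System.out.println(toFilePath(url));  // decoded form, contains "models dir"
    }
}
```

Under these assumptions, the reader in the test could be constructed as `new FileSystemModelReader(ResourcePaths.toFilePath(url))` instead of passing `url.getPath()` directly.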
