
Example 1 with KNNModelFormat

Use of org.apache.ignite.ml.knn.ann.KNNModelFormat in the Apache Ignite project (ignite).

Source: class LocalModelsTest, method importExportANNModelTest.

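/**
 * Saves an ANN classification model to a file, loads its format back and checks
 * that a model rebuilt from the loaded format equals the original one.
 */
@Test
public void importExportANNModelTest() throws IOException {
    executeModelTest(mdlFilePath -> {
        // Empty set of centroids: the test only checks the export/import round trip.
        final LabeledVectorSet<LabeledVector> centers = new LabeledVectorSet<>();

        NNClassificationModel mdl = new ANNClassificationModel(centers, new ANNClassificationTrainer.CentroidStat())
            .withK(4)
            .withDistanceMeasure(new ManhattanDistance())
            .withWeighted(true);

        // Write the model's format object to mdlFilePath and read it back.
        Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
        mdl.saveModel(exporter, mdlFilePath);
        ANNModelFormat load = (ANNModelFormat) exporter.load(mdlFilePath);
        Assert.assertNotNull(load);

        // Rebuild the model from the loaded candidates, k and distance measure;
        // the weighted flag is set explicitly to match the original model.
        NNClassificationModel importedMdl = new ANNClassificationModel(load.getCandidates(), new ANNClassificationTrainer.CentroidStat())
            .withK(load.getK())
            .withDistanceMeasure(load.getDistanceMeasure())
            .withWeighted(true);

        Assert.assertEquals("Imported model should equal the exported one", mdl, importedMdl);
        return null;
    });
}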
Also used: ANNClassificationModel (org.apache.ignite.ml.knn.ann.ANNClassificationModel), NNClassificationModel (org.apache.ignite.ml.knn.NNClassificationModel), FileExporter (org.apache.ignite.ml.FileExporter), KNNModelFormat (org.apache.ignite.ml.knn.ann.KNNModelFormat), LabeledVector (org.apache.ignite.ml.structures.LabeledVector), ANNModelFormat (org.apache.ignite.ml.knn.ann.ANNModelFormat), LabeledVectorSet (org.apache.ignite.ml.structures.LabeledVectorSet), ManhattanDistance (org.apache.ignite.ml.math.distances.ManhattanDistance), Test (org.junit.Test)
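The test relies on `executeModelTest` and `FileExporter`, whose bodies are not shown here. As a rough illustration of the same save, load and compare round trip without any Ignite dependency, the following plain-Java sketch assumes Java serialization as the on-disk format. `ExportImportSketch`, `KnnFormatSketch`, `FileExporterSketch` and this `executeModelTest` are hypothetical stand-ins, not Ignite classes and not the test's actual helper.

```java
import java.io.*;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.Objects;

/** Hypothetical, Ignite-free sketch of the save -> load -> compare round trip. */
class ExportImportSketch {

    /** Hypothetical stand-in for a KNN model format: k, distance-measure name, weighting flag. */
    static final class KnnFormatSketch implements Serializable {
        private static final long serialVersionUID = 1L;
        final int k;
        final String distanceMeasure;
        final boolean weighted;

        KnnFormatSketch(int k, String distanceMeasure, boolean weighted) {
            this.k = k;
            this.distanceMeasure = distanceMeasure;
            this.weighted = weighted;
        }

        @Override public boolean equals(Object o) {
            if (this == o) return true;
            if (!(o instanceof KnnFormatSketch)) return false;
            KnnFormatSketch other = (KnnFormatSketch) o;
            return k == other.k && weighted == other.weighted
                && distanceMeasure.equals(other.distanceMeasure);
        }

        @Override public int hashCode() {
            return Objects.hash(k, distanceMeasure, weighted);
        }
    }

    /** Hypothetical file exporter; assumes plain Java serialization as the file format. */
    static final class FileExporterSketch<D extends Serializable> {
        void save(D data, String path) throws IOException {
            try (ObjectOutputStream out = new ObjectOutputStream(new FileOutputStream(path))) {
                out.writeObject(data);
            }
        }

        @SuppressWarnings("unchecked")
        D load(String path) throws IOException, ClassNotFoundException {
            try (ObjectInputStream in = new ObjectInputStream(new FileInputStream(path))) {
                return (D) in.readObject();
            }
        }
    }

    /** Body that receives the path of a temporary model file. */
    interface ModelFileBody<R> {
        R run(String mdlFilePath) throws Exception;
    }

    /** Hypothetical analogue of executeModelTest: creates a temp file, runs the body, deletes the file. */
    static <R> R executeModelTest(ModelFileBody<R> body) throws Exception {
        Path tmp = Files.createTempFile("mdl", ".bin");
        try {
            return body.run(tmp.toString());
        } finally {
            Files.deleteIfExists(tmp);
        }
    }

    /** Saves the format to a temp file and returns the copy loaded back from it. */
    static KnnFormatSketch roundTrip(KnnFormatSketch original) {
        try {
            return executeModelTest(path -> {
                FileExporterSketch<KnnFormatSketch> exporter = new FileExporterSketch<>();
                exporter.save(original, path);
                return exporter.load(path);
            });
        } catch (Exception e) {
            throw new IllegalStateException("model round trip failed", e);
        }
    }

    public static void main(String[] args) {
        KnnFormatSketch original = new KnnFormatSketch(4, "ManhattanDistance", true);
        KnnFormatSketch loaded = roundTrip(original);
        // Equal by value, but a distinct object read back from the file: prints true.
        System.out.println(original.equals(loaded) && original != loaded);
    }
}
```

As in the Ignite test, the check is value equality between what was written and what was read back, so the format class needs a meaningful `equals`; comparing references would always fail after deserialization.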
