Search in sources:

Example 1 with ANNClassificationModel

Use of org.apache.ignite.ml.knn.ann.ANNClassificationModel in project ignite by apache.

The class ANNClassificationExportImportExample, method main.

/**
 * Entry point of the example: fits an ANN classifier on a cached dataset, writes the model to a temporary
 * JSON file, reads it back and prints both models together with the accuracy computed for each of them.
 *
 * @param args Command line arguments (not read by this method).
 * @throws IOException Declared for the temporary-file handling and the JSON export/import calls.
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");

    // The grid instance is closed by try-with-resources once the example body is done.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Both are declared up front so the cleanup section below can see them.
        IgniteCache<Integer, double[]> trainingCache = null;
        Path exportedMdlFile = null;

        try {
            trainingCache = getTestCache(ignite);

            ANNClassificationTrainer trainer = new ANNClassificationTrainer()
                .withDistance(new ManhattanDistance())
                .withK(50)
                .withMaxIterations(1000)
                .withEpsilon(1e-2);

            ANNClassificationModel exportedMdl = (ANNClassificationModel) trainer
                .fit(ignite, trainingCache, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
                .withK(5)
                .withDistanceMeasure(new EuclideanDistance())
                .withWeighted(true);

            System.out.println("\n>>> Exported ANN model: " + exportedMdl.toString(true));

            double exportedAccuracy = evaluateModel(trainingCache, exportedMdl);
            System.out.println("\n>>> Accuracy for exported ANN model:" + exportedAccuracy);

            // Round trip through a temporary JSON file.
            exportedMdlFile = Files.createTempFile(null, null);
            exportedMdl.toJSON(exportedMdlFile);

            ANNClassificationModel importedMdl = ANNClassificationModel.fromJSON(exportedMdlFile);
            System.out.println("\n>>> Imported ANN model: " + importedMdl.toString(true));

            double importedAccuracy = evaluateModel(trainingCache, importedMdl);
            System.out.println("\n>>> Accuracy for imported ANN model:" + importedAccuracy);

            System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
        }
        finally {
            // Cleanup: drop the cache first, then remove the temporary model file.
            if (trainingCache != null) {
                trainingCache.destroy();
            }

            if (exportedMdlFile != null) {
                Files.deleteIfExists(exportedMdlFile);
            }
        }
    }
    finally {
        System.out.flush();
    }
}
Also used : Path(java.nio.file.Path) EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) ANNClassificationModel(org.apache.ignite.ml.knn.ann.ANNClassificationModel) Ignite(org.apache.ignite.Ignite) ANNClassificationTrainer(org.apache.ignite.ml.knn.ann.ANNClassificationTrainer) ManhattanDistance(org.apache.ignite.ml.math.distances.ManhattanDistance)

Example 2 with ANNClassificationModel

Use of org.apache.ignite.ml.knn.ann.ANNClassificationModel in project ignite by apache.

The class ANNClassificationTest, method testUpdate.

/**
 * Fits a model, then calls {@code trainer.update(...)} once with the same data and once with an empty map,
 * and checks that each updated model has non-null candidates and prints {@code weighted = [false]} in all
 * three {@code toString} variants.
 */
@Test
public void testUpdate() {
    // Key each row of the fixture by its position in the array.
    Map<Integer, double[]> data = new HashMap<>();
    int rowIdx = 0;
    for (double[] row : twoClusters) {
        data.put(rowIdx++, row);
    }

    ANNClassificationTrainer trainer = new ANNClassificationTrainer()
        .withK(10)
        .withMaxIterations(10)
        .withEpsilon(1e-4)
        .withDistance(new EuclideanDistance());

    ANNClassificationModel baseMdl = (ANNClassificationModel) trainer
        .fit(data, parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    // NOTE(review): fit() above labels by FIRST while both update() calls label by LAST — confirm this is intended.
    ANNClassificationModel sameDataMdl = (ANNClassificationModel) trainer
        .update(baseMdl, data, parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST))
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    ANNClassificationModel emptyDataMdl = (ANNClassificationModel) trainer
        .update(baseMdl, new HashMap<>(), parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST))
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    // Model updated on the original dataset.
    Assert.assertNotNull(sameDataMdl.getCandidates());
    assertTrue(sameDataMdl.toString().contains("weighted = [false]"));
    assertTrue(sameDataMdl.toString(true).contains("weighted = [false]"));
    assertTrue(sameDataMdl.toString(false).contains("weighted = [false]"));

    // Model updated on an empty dataset.
    assertNotNull(emptyDataMdl.getCandidates());
    assertTrue(emptyDataMdl.toString().contains("weighted = [false]"));
    assertTrue(emptyDataMdl.toString(true).contains("weighted = [false]"));
    assertTrue(emptyDataMdl.toString(false).contains("weighted = [false]"));
}
Also used : EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) HashMap(java.util.HashMap) ANNClassificationModel(org.apache.ignite.ml.knn.ann.ANNClassificationModel) ANNClassificationTrainer(org.apache.ignite.ml.knn.ann.ANNClassificationTrainer) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 3 with ANNClassificationModel

Use of org.apache.ignite.ml.knn.ann.ANNClassificationModel in project ignite by apache.

The class ANNClassificationTest, method testBinaryClassification.

/**
 * Verifies that the trainer getters return the values passed to the {@code with*} calls, then fits a model
 * and checks that it has non-null candidates and prints {@code weighted = [false]} in all three
 * {@code toString} variants.
 */
@Test
public void testBinaryClassification() {
    // Key each row of the fixture by its position in the array.
    Map<Integer, double[]> data = new HashMap<>();
    int rowIdx = 0;
    for (double[] row : twoClusters) {
        data.put(rowIdx++, row);
    }

    ANNClassificationTrainer trainer = new ANNClassificationTrainer()
        .withK(10)
        .withMaxIterations(10)
        .withEpsilon(1e-4)
        .withDistance(new EuclideanDistance());

    // The configured hyper-parameters must be readable back from the trainer.
    Assert.assertEquals(10, trainer.getK());
    Assert.assertEquals(10, trainer.getMaxIterations());
    TestUtils.assertEquals(1e-4, trainer.getEpsilon(), PRECISION);
    Assert.assertEquals(new EuclideanDistance(), trainer.getDistance());

    NNClassificationModel fitted = trainer
        .fit(data, parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    ANNClassificationModel annMdl = (ANNClassificationModel) fitted;
    Assert.assertNotNull(annMdl.getCandidates());

    assertTrue(fitted.toString().contains("weighted = [false]"));
    assertTrue(fitted.toString(true).contains("weighted = [false]"));
    assertTrue(fitted.toString(false).contains("weighted = [false]"));
}
Also used : EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) ANNClassificationModel(org.apache.ignite.ml.knn.ann.ANNClassificationModel) HashMap(java.util.HashMap) ANNClassificationTrainer(org.apache.ignite.ml.knn.ann.ANNClassificationTrainer) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 4 with ANNClassificationModel

Use of org.apache.ignite.ml.knn.ann.ANNClassificationModel in project ignite by apache.

The class CollectionsTest, method test.

/**
 * Feeds pairs of ML objects (matrix views, distance measures, dataset structures and models) to the
 * {@code test(...)} and {@code specialTest(...)} helpers.
 * <p>
 * The helpers are not visible here; presumably they verify the {@code equals}/{@code hashCode} behavior of
 * the passed objects — confirm against their definitions.
 */
@Test
@SuppressWarnings("unchecked")
public void test() {
    // Two matrix views built over dense matrices of different dimensions.
    test(new VectorizedViewMatrix(new DenseMatrix(2, 2), 1, 1, 1, 1), new VectorizedViewMatrix(new DenseMatrix(3, 2), 2, 1, 1, 1));
    // Distance measures go through specialTest with two fresh instances of the same class.
    specialTest(new ManhattanDistance(), new ManhattanDistance());
    specialTest(new HammingDistance(), new HammingDistance());
    specialTest(new EuclideanDistance(), new EuclideanDistance());
    // Created as "name2" and then renamed to "name1", so it differs from the fresh "name2" instance below.
    FeatureMetadata data = new FeatureMetadata("name2");
    data.setName("name1");
    test(data, new FeatureMetadata("name2"));
    // Dataset structures: default-constructed vector vs. a vector created with argument 1.
    test(new DatasetRow<>(new DenseVector()), new DatasetRow<>(new DenseVector(1)));
    test(new LabeledVector<>(new DenseVector(), null), new LabeledVector<>(new DenseVector(1), null));
    // Raw arrays are passed to the generic Dataset constructor here (see @SuppressWarnings above).
    test(new Dataset<DatasetRow<Vector>>(new DatasetRow[] {}, new FeatureMetadata[] {}), new Dataset<DatasetRow<Vector>>(new DatasetRow[] { new DatasetRow() }, new FeatureMetadata[] { new FeatureMetadata() }));
    // Models: each pair differs in exactly one constructor argument.
    test(new LogisticRegressionModel(new DenseVector(), 1.0), new LogisticRegressionModel(new DenseVector(), 0.5));
    test(new KMeansModelFormat(new Vector[] {}, new ManhattanDistance()), new KMeansModelFormat(new Vector[] {}, new HammingDistance()));
    test(new KMeansModel(new Vector[] {}, new ManhattanDistance()), new KMeansModel(new Vector[] {}, new HammingDistance()));
    test(new SVMLinearClassificationModel(null, 1.0), new SVMLinearClassificationModel(null, 0.5));
    // ANN model and its persisted format: the pairs differ in the candidate set and in the first argument (1 vs. 2).
    test(new ANNClassificationModel(new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()), new ANNClassificationModel(new LabeledVectorSet<>(1, 1), new ANNClassificationTrainer.CentroidStat()));
    test(new ANNModelFormat(1, new ManhattanDistance(), false, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()), new ANNModelFormat(2, new ManhattanDistance(), false, new LabeledVectorSet<>(), new ANNClassificationTrainer.CentroidStat()));
}
Also used : FeatureMetadata(org.apache.ignite.ml.structures.FeatureMetadata) HammingDistance(org.apache.ignite.ml.math.distances.HammingDistance) KMeansModel(org.apache.ignite.ml.clustering.kmeans.KMeansModel) LogisticRegressionModel(org.apache.ignite.ml.regressions.logistic.LogisticRegressionModel) ANNModelFormat(org.apache.ignite.ml.knn.ann.ANNModelFormat) LabeledVectorSet(org.apache.ignite.ml.structures.LabeledVectorSet) KMeansModelFormat(org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) DatasetRow(org.apache.ignite.ml.structures.DatasetRow) VectorizedViewMatrix(org.apache.ignite.ml.math.primitives.vector.impl.VectorizedViewMatrix) ANNClassificationModel(org.apache.ignite.ml.knn.ann.ANNClassificationModel) SVMLinearClassificationModel(org.apache.ignite.ml.svm.SVMLinearClassificationModel) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) ManhattanDistance(org.apache.ignite.ml.math.distances.ManhattanDistance) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) Test(org.junit.Test)

Example 5 with ANNClassificationModel

Use of org.apache.ignite.ml.knn.ann.ANNClassificationModel in project ignite by apache.

The class LocalModelsTest, method importExportANNModelTest.

/**
 * Saves an ANN model through a {@link FileExporter}, loads the stored {@link ANNModelFormat} back, rebuilds
 * a model from the loaded candidates, k and distance measure, and asserts it equals the saved model.
 *
 * @throws IOException Declared by the test; presumably propagated from {@code executeModelTest} — confirm.
 */
@Test
public void importExportANNModelTest() throws IOException {
    executeModelTest(path -> {
        LabeledVectorSet<LabeledVector> centers = new LabeledVectorSet<>();

        NNClassificationModel savedMdl = new ANNClassificationModel(centers, new ANNClassificationTrainer.CentroidStat())
            .withK(4)
            .withDistanceMeasure(new ManhattanDistance())
            .withWeighted(true);

        // Export to the given path and read the persisted format back with the same exporter.
        Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
        savedMdl.saveModel(exporter, path);

        ANNModelFormat loadedFormat = (ANNModelFormat) exporter.load(path);
        Assert.assertNotNull(loadedFormat);

        // The weighted flag is set explicitly here rather than read from the loaded format.
        NNClassificationModel restoredMdl = new ANNClassificationModel(loadedFormat.getCandidates(), new ANNClassificationTrainer.CentroidStat())
            .withK(loadedFormat.getK())
            .withDistanceMeasure(loadedFormat.getDistanceMeasure())
            .withWeighted(true);

        Assert.assertEquals("", savedMdl, restoredMdl);

        return null;
    });
}
Also used : ANNClassificationModel(org.apache.ignite.ml.knn.ann.ANNClassificationModel) NNClassificationModel(org.apache.ignite.ml.knn.NNClassificationModel) FileExporter(org.apache.ignite.ml.FileExporter) KNNModelFormat(org.apache.ignite.ml.knn.ann.KNNModelFormat) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) ANNClassificationModel(org.apache.ignite.ml.knn.ann.ANNClassificationModel) ANNModelFormat(org.apache.ignite.ml.knn.ann.ANNModelFormat) LabeledVectorSet(org.apache.ignite.ml.structures.LabeledVectorSet) ManhattanDistance(org.apache.ignite.ml.math.distances.ManhattanDistance) Test(org.junit.Test)

Aggregations

ANNClassificationModel (org.apache.ignite.ml.knn.ann.ANNClassificationModel): 5, EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance): 4, Test (org.junit.Test): 4, ANNClassificationTrainer (org.apache.ignite.ml.knn.ann.ANNClassificationTrainer): 3, ManhattanDistance (org.apache.ignite.ml.math.distances.ManhattanDistance): 3, HashMap (java.util.HashMap): 2, TrainerTest (org.apache.ignite.ml.common.TrainerTest): 2, DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer): 2, ANNModelFormat (org.apache.ignite.ml.knn.ann.ANNModelFormat): 2, LabeledVector (org.apache.ignite.ml.structures.LabeledVector): 2, LabeledVectorSet (org.apache.ignite.ml.structures.LabeledVectorSet): 2, Path (java.nio.file.Path): 1, Ignite (org.apache.ignite.Ignite): 1, FileExporter (org.apache.ignite.ml.FileExporter): 1, KMeansModel (org.apache.ignite.ml.clustering.kmeans.KMeansModel): 1, KMeansModelFormat (org.apache.ignite.ml.clustering.kmeans.KMeansModelFormat): 1, NNClassificationModel (org.apache.ignite.ml.knn.NNClassificationModel): 1, KNNModelFormat (org.apache.ignite.ml.knn.ann.KNNModelFormat): 1, HammingDistance (org.apache.ignite.ml.math.distances.HammingDistance): 1, DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix): 1