Search in sources:

Example 1 with ANNClassificationTrainer

Use of org.apache.ignite.ml.knn.ann.ANNClassificationTrainer in the Apache Ignite project.

Source: class ANNClassificationExportImportExample, method main.

/**
 * Entry point of the export/import example.
 *
 * <p>Trains an ANN classifier on the test cache, prints the model and its accuracy, writes the
 * model to a temporary JSON file, reads it back and prints the re-imported model together with
 * its accuracy. The test cache and the temporary file are removed in all cases.
 *
 * @param args Command line arguments (not used).
 * @throws IOException Propagated from the temporary-file handling
 *     ({@code Files.createTempFile}, {@code Files.deleteIfExists}).
 */
public static void main(String[] args) throws IOException {
    System.out.println();
    System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Both resources are tracked outside the try so the finally block can release them.
        IgniteCache<Integer, double[]> cache = null;
        Path mdlFile = null;

        try {
            cache = getTestCache(ignite);

            ANNClassificationTrainer annTrainer = new ANNClassificationTrainer()
                .withDistance(new ManhattanDistance())
                .withK(50)
                .withMaxIterations(1000)
                .withEpsilon(1e-2);

            ANNClassificationModel exportedMdl = (ANNClassificationModel) annTrainer
                .fit(ignite, cache, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
                .withK(5)
                .withDistanceMeasure(new EuclideanDistance())
                .withWeighted(true);

            System.out.println("\n>>> Exported ANN model: " + exportedMdl.toString(true));
            double exportedAccuracy = evaluateModel(cache, exportedMdl);
            System.out.println("\n>>> Accuracy for exported ANN model:" + exportedAccuracy);

            // Round-trip the model through JSON on disk.
            mdlFile = Files.createTempFile(null, null);
            exportedMdl.toJSON(mdlFile);
            ANNClassificationModel importedMdl = ANNClassificationModel.fromJSON(mdlFile);

            System.out.println("\n>>> Imported ANN model: " + importedMdl.toString(true));
            double importedAccuracy = evaluateModel(cache, importedMdl);
            System.out.println("\n>>> Accuracy for imported ANN model:" + importedAccuracy);

            System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
        }
        finally {
            if (cache != null) {
                cache.destroy();
            }
            if (mdlFile != null) {
                Files.deleteIfExists(mdlFile);
            }
        }
    }
    finally {
        System.out.flush();
    }
}
Also used : Path(java.nio.file.Path) EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) ANNClassificationModel(org.apache.ignite.ml.knn.ann.ANNClassificationModel) Ignite(org.apache.ignite.Ignite) ANNClassificationTrainer(org.apache.ignite.ml.knn.ann.ANNClassificationTrainer) ManhattanDistance(org.apache.ignite.ml.math.distances.ManhattanDistance)

Example 2 with ANNClassificationTrainer

Use of org.apache.ignite.ml.knn.ann.ANNClassificationTrainer in the Apache Ignite project.

Source: class ANNClassificationExample, method main.

/**
 * Entry point of the ANN classification example.
 *
 * <p>Trains an ANN classifier on the test cache, prints a prediction-vs-ground-truth table for
 * every cached observation, and reports the training time and total prediction time (both in
 * milliseconds), the absolute number of misclassifications and the accuracy. The test cache is
 * destroyed afterwards if it was created.
 *
 * @param args Command line arguments (not used).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");
        IgniteCache<Integer, double[]> dataCache = null;
        try {
            dataCache = getTestCache(ignite);
            ANNClassificationTrainer trainer = new ANNClassificationTrainer()
                .withDistance(new ManhattanDistance())
                .withK(50)
                .withMaxIterations(1000)
                .withEpsilon(1e-2);

            long startTrainingTime = System.currentTimeMillis();
            NNClassificationModel knnMdl = trainer
                .fit(ignite, dataCache, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
                .withK(5)
                .withDistanceMeasure(new EuclideanDistance())
                .withWeighted(true);
            long endTrainingTime = System.currentTimeMillis();

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");

            int amountOfErrors = 0;
            int totalAmount = 0;
            // Accumulated in nanoseconds: a single prediction usually takes well under a
            // millisecond, so summing per-call currentTimeMillis() deltas would mostly add zeros.
            long totalPredictionNanos = 0L;

            try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, double[]> observation : observations) {
                    double[] val = observation.getValue();
                    // Layout used by this example: label in column 0, features in the rest.
                    double[] inputs = Arrays.copyOfRange(val, 1, val.length);
                    double groundTruth = val[0];

                    long startPredictionTime = System.nanoTime();
                    double prediction = knnMdl.predict(new DenseVector(inputs));
                    totalPredictionNanos += System.nanoTime() - startPredictionTime;

                    totalAmount++;
                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
                        amountOfErrors++;

                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
                }
                System.out.println(">>> ---------------------------------");
                System.out.println("Training costs = " + (endTrainingTime - startTrainingTime));
                // Report in milliseconds, the same unit as the training cost above.
                System.out.println("Prediction costs = " + (totalPredictionNanos / 1_000_000L));
                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
                System.out.println(totalAmount);
                System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
            }
        } finally {
            // getTestCache may fail before dataCache is assigned; an unguarded destroy() would
            // throw a NullPointerException that hides the original failure.
            if (dataCache != null)
                dataCache.destroy();
        }
    } finally {
        System.out.flush();
    }
}
Also used : NNClassificationModel(org.apache.ignite.ml.knn.NNClassificationModel) EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) Ignite(org.apache.ignite.Ignite) ANNClassificationTrainer(org.apache.ignite.ml.knn.ann.ANNClassificationTrainer) ManhattanDistance(org.apache.ignite.ml.math.distances.ManhattanDistance) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) IgniteCache(org.apache.ignite.IgniteCache) Cache(javax.cache.Cache)

Example 3 with ANNClassificationTrainer

Use of org.apache.ignite.ml.knn.ann.ANNClassificationTrainer in the Apache Ignite project.

Source: class ANNClassificationTest, method testUpdate.

/**
 * Checks that {@code ANNClassificationTrainer.update} yields a model with candidates both when
 * re-run on the training data and when given an empty dataset, and that the
 * {@code withWeighted(false)} setting shows up in every {@code toString} variant of the result.
 *
 * <p>Every fit/update call uses the same vectorizer configuration (label in the FIRST
 * coordinate). Updating with a different label coordinate would make the trainer interpret a
 * feature column as the label, so the update would no longer be on "the same dataset".
 */
@Test
public void testUpdate() {
    Map<Integer, double[]> cacheMock = new HashMap<>();
    for (int i = 0; i < twoClusters.length; i++)
        cacheMock.put(i, twoClusters[i]);

    ANNClassificationTrainer trainer = new ANNClassificationTrainer()
        .withK(10)
        .withMaxIterations(10)
        .withEpsilon(1e-4)
        .withDistance(new EuclideanDistance());

    ANNClassificationModel originalMdl = (ANNClassificationModel) trainer.fit(
        cacheMock,
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
    ).withK(3).withDistanceMeasure(new EuclideanDistance()).withWeighted(false);

    ANNClassificationModel updatedOnSameDataset = (ANNClassificationModel) trainer.update(
        originalMdl,
        cacheMock,
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
    ).withK(3).withDistanceMeasure(new EuclideanDistance()).withWeighted(false);

    ANNClassificationModel updatedOnEmptyDataset = (ANNClassificationModel) trainer.update(
        originalMdl,
        new HashMap<>(),
        parts,
        new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST)
    ).withK(3).withDistanceMeasure(new EuclideanDistance()).withWeighted(false);

    assertNotNull(updatedOnSameDataset.getCandidates());
    assertTrue(updatedOnSameDataset.toString().contains("weighted = [false]"));
    assertTrue(updatedOnSameDataset.toString(true).contains("weighted = [false]"));
    assertTrue(updatedOnSameDataset.toString(false).contains("weighted = [false]"));

    assertNotNull(updatedOnEmptyDataset.getCandidates());
    assertTrue(updatedOnEmptyDataset.toString().contains("weighted = [false]"));
    assertTrue(updatedOnEmptyDataset.toString(true).contains("weighted = [false]"));
    assertTrue(updatedOnEmptyDataset.toString(false).contains("weighted = [false]"));
}
Also used : EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) HashMap(java.util.HashMap) ANNClassificationModel(org.apache.ignite.ml.knn.ann.ANNClassificationModel) ANNClassificationTrainer(org.apache.ignite.ml.knn.ann.ANNClassificationTrainer) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 4 with ANNClassificationTrainer

Use of org.apache.ignite.ml.knn.ann.ANNClassificationTrainer in the Apache Ignite project.

Source: class ANNClassificationTest, method testBinaryClassification.

/**
 * Trains an ANN classifier on the {@code twoClusters} fixture and verifies two things: the
 * trainer's getters return the values passed to its fluent setters, and the fitted model has
 * candidates and reports {@code weighted = [false]} in all of its string representations.
 */
@Test
public void testBinaryClassification() {
    Map<Integer, double[]> data = new HashMap<>();
    int idx = 0;
    for (double[] row : twoClusters)
        data.put(idx++, row);

    ANNClassificationTrainer annTrainer = new ANNClassificationTrainer()
        .withK(10)
        .withMaxIterations(10)
        .withEpsilon(1e-4)
        .withDistance(new EuclideanDistance());

    // The fluent setters must be reflected by the corresponding getters.
    Assert.assertEquals(10, annTrainer.getK());
    Assert.assertEquals(10, annTrainer.getMaxIterations());
    TestUtils.assertEquals(1e-4, annTrainer.getEpsilon(), PRECISION);
    Assert.assertEquals(new EuclideanDistance(), annTrainer.getDistance());

    NNClassificationModel model = annTrainer
        .fit(data, parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
        .withK(3)
        .withDistanceMeasure(new EuclideanDistance())
        .withWeighted(false);

    ANNClassificationModel annModel = (ANNClassificationModel) model;
    Assert.assertNotNull(annModel.getCandidates());

    String weightedMarker = "weighted = [false]";
    assertTrue(model.toString().contains(weightedMarker));
    assertTrue(model.toString(true).contains(weightedMarker));
    assertTrue(model.toString(false).contains(weightedMarker));
}
Also used : EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) ANNClassificationModel(org.apache.ignite.ml.knn.ann.ANNClassificationModel) HashMap(java.util.HashMap) ANNClassificationTrainer(org.apache.ignite.ml.knn.ann.ANNClassificationTrainer) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Aggregations

ANNClassificationTrainer (org.apache.ignite.ml.knn.ann.ANNClassificationTrainer): 4 usages
EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance): 4 usages
ANNClassificationModel (org.apache.ignite.ml.knn.ann.ANNClassificationModel): 3 usages
HashMap (java.util.HashMap): 2 usages
Ignite (org.apache.ignite.Ignite): 2 usages
TrainerTest (org.apache.ignite.ml.common.TrainerTest): 2 usages
DoubleArrayVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer): 2 usages
ManhattanDistance (org.apache.ignite.ml.math.distances.ManhattanDistance): 2 usages
Test (org.junit.Test): 2 usages
Path (java.nio.file.Path): 1 usage
Cache (javax.cache.Cache): 1 usage
IgniteCache (org.apache.ignite.IgniteCache): 1 usage
NNClassificationModel (org.apache.ignite.ml.knn.NNClassificationModel): 1 usage
DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector): 1 usage