Search in sources :

Example 1 with KNNModel

Use of org.apache.ignite.ml.knn.models.KNNModel in the Apache Ignite project.

In class KNNClassificationExample, method main.

/**
 * Executes example.
 * <p>
 * Reads the iris dataset from {@code KNN_IRIS_TXT}, randomly splits it into train and test parts, labels the test
 * part with a {@link KNNModel} (k = 5, Euclidean distance, {@code KNNStrategy.SIMPLE}) built on the train part,
 * and prints the error count, accuracy, confusion matrix and per-class precision / recall / F-score.
 *
 * @param args Command line arguments, none required.
 * @throws InterruptedException If interrupted while waiting for the example thread to complete.
 */
public static void main(String[] args) throws InterruptedException {
    System.out.println(">>> kNN classification example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // The ML part of the example runs in an IgniteThread created for the started Ignite instance.
        IgniteThread igniteThread = new IgniteThread(ignite.configuration().getIgniteInstanceName(),
            KNNClassificationExample.class.getSimpleName(), () -> {
            try {
                // Prepare path to read.
                File file = IgniteUtils.resolveIgnitePath(KNN_IRIS_TXT);

                if (file == null)
                    throw new RuntimeException("Can't find file: " + KNN_IRIS_TXT);

                Path path = file.toPath();

                // Read dataset from file.
                LabeledDataset dataset = LabeledDatasetLoader.loadFromTxtFile(path, SEPARATOR, true, false);

                // Random splitting of iris data as 70% train and 30% test datasets.
                LabeledDatasetTestTrainPair split = new LabeledDatasetTestTrainPair(dataset, 0.3);

                LabeledDataset test = split.test();
                LabeledDataset train = split.train();

                System.out.println("\n>>> Amount of observations in train dataset " + train.rowSize());
                System.out.println("\n>>> Amount of observations in test dataset " + test.rowSize());

                KNNModel knnMdl = new KNNModel(5, new EuclideanDistance(), KNNStrategy.SIMPLE, train);

                // Keep the true labels: assignLabels below overwrites the labels of the test dataset with
                // predictions. NOTE(review): relies on labels() returning a copy, not a live view — confirm.
                final double[] labels = test.labels();

                // Save predicted classes to test dataset.
                LabellingMachine.assignLabels(test, knnMdl);

                // Calculate amount of errors on test dataset.
                int amountOfErrors = 0;

                for (int i = 0; i < test.rowSize(); i++) {
                    if (test.label(i) != labels[i])
                        amountOfErrors++;
                }

                // Fraction of misclassified test samples; accuracy is its complement.
                double errRate = amountOfErrors / (double) test.rowSize();

                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                System.out.println("\n>>> Accuracy " + (1 - errRate));

                // Number of classes in the iris data; class labels are expected to be 1.0 .. clsCnt.
                final int clsCnt = 3;

                // Build confusion matrix. See https://en.wikipedia.org/wiki/Confusion_matrix
                // Rows are predicted classes, columns are true classes.
                int[][] confusionMtx = new int[clsCnt][clsCnt];

                for (int i = 0; i < test.rowSize(); i++) {
                    int predictedIdx = (int) test.label(i);
                    int trueIdx = (int) labels[i];

                    // Labels outside 1..clsCnt would throw ArrayIndexOutOfBoundsException here.
                    confusionMtx[predictedIdx - 1][trueIdx - 1]++;
                }

                System.out.println("\n>>> Confusion matrix is " + Arrays.deepToString(confusionMtx));

                // Calculate precision, recall and F-metric for each class.
                // Ratios with an empty denominator are reported as 0 instead of NaN.
                for (int i = 0; i < clsCnt; i++) {
                    double clsLb = (double) (i + 1);

                    // Precision: correct predictions of the class / all predictions of the class (row sum).
                    int predictedAsCls = 0;

                    for (int j = 0; j < clsCnt; j++)
                        predictedAsCls += confusionMtx[i][j];

                    double precision = predictedAsCls == 0 ? 0.0 : confusionMtx[i][i] / (double) predictedAsCls;

                    System.out.println("\n>>> Precision for class " + clsLb + " is " + precision);

                    // Recall: correct predictions of the class / test samples truly of the class (column sum).
                    int actualCls = 0;

                    for (int j = 0; j < clsCnt; j++)
                        actualCls += confusionMtx[j][i];

                    double recall = actualCls == 0 ? 0.0 : confusionMtx[i][i] / (double) actualCls;

                    System.out.println("\n>>> Recall for class " + clsLb + " is " + recall);

                    // F-score is the harmonic mean of precision and recall; 0 when both are 0.
                    double fScore = precision + recall == 0.0 ? 0.0 : 2 * precision * recall / (precision + recall);

                    System.out.println("\n>>> F-score for class " + clsLb + " is " + fScore);
                }
            } catch (IOException e) {
                e.printStackTrace();
                System.out.println("\n>>> Unexpected exception, check resources: " + e);
            } finally {
                System.out.println("\n>>> kNN classification example completed.");
            }
        });

        igniteThread.start();
        igniteThread.join();
    }
}
Also used : Path(java.nio.file.Path) IOException(java.io.IOException) LabeledDataset(org.apache.ignite.ml.structures.LabeledDataset) EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) LabeledDatasetTestTrainPair(org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair) KNNModel(org.apache.ignite.ml.knn.models.KNNModel) Ignite(org.apache.ignite.Ignite) IgniteThread(org.apache.ignite.thread.IgniteThread) File(java.io.File)

Example 2 with KNNModel

Use of org.apache.ignite.ml.knn.models.KNNModel in the Apache Ignite project.

In class LocalModelsTest, method importExportKNNModelTest.

/**
 * Checks that a {@link KNNModel} saved through a {@link FileExporter} can be loaded back and rebuilt into a model
 * equal to the original one.
 */
@Test
public void importExportKNNModelTest() throws IOException {
    executeModelTest(mdlFilePath -> {
        // Points with positive coordinates are labeled 1.0, points with negative coordinates are labeled 2.0.
        double[][] mtx = new double[][] { { 1.0, 1.0 }, { 1.0, 2.0 }, { 2.0, 1.0 }, { -1.0, -1.0 }, { -1.0, -2.0 }, { -2.0, -1.0 } };
        double[] lbs = new double[] { 1.0, 1.0, 1.0, 2.0, 2.0, 2.0 };
        LabeledDataset training = new LabeledDataset(mtx, lbs);

        KNNModel mdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training);

        Exporter<KNNModelFormat, String> exporter = new FileExporter<>();
        mdl.saveModel(exporter, mdlFilePath);

        KNNModelFormat load = exporter.load(mdlFilePath);
        Assert.assertNotNull(load);

        KNNModel importedMdl = new KNNModel(load.getK(), load.getDistanceMeasure(), load.getStgy(), load.getTraining());

        // Same equals() check as before, but a failure now reports both values instead of an empty message.
        Assert.assertEquals(mdl, importedMdl);

        return null;
    });
}
Also used : EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) KNNModelFormat(org.apache.ignite.ml.knn.models.KNNModelFormat) KNNModel(org.apache.ignite.ml.knn.models.KNNModel) LabeledDataset(org.apache.ignite.ml.structures.LabeledDataset) Test(org.junit.Test)

Example 3 with KNNModel

Use of org.apache.ignite.ml.knn.models.KNNModel in the Apache Ignite project.

In class KNNClassificationTest, method testBinaryClassificationWithSmallestKTest.

/**
 * Checks binary classification with the smallest possible neighbourhood size, k = 1.
 */
public void testBinaryClassificationWithSmallestKTest() {
    IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

    // Points with positive coordinates are labeled 1.0, points with negative coordinates are labeled 2.0.
    double[][] mtx = new double[][] { { 1.0, 1.0 }, { 1.0, 2.0 }, { 2.0, 1.0 }, { -1.0, -1.0 }, { -1.0, -2.0 }, { -2.0, -1.0 } };
    double[] lbs = new double[] { 1.0, 1.0, 1.0, 2.0, 2.0, 2.0 };
    LabeledDataset training = new LabeledDataset(mtx, lbs);

    KNNModel knnMdl = new KNNModel(1, new EuclideanDistance(), KNNStrategy.SIMPLE, training);

    // assertEquals takes the expected value first, so failure messages report expected/actual correctly.

    // Nearest training points to (2, 2) are (1, 2) and (2, 1), both labeled 1.0.
    Vector firstVector = new DenseLocalOnHeapVector(new double[] { 2.0, 2.0 });
    assertEquals(1.0, knnMdl.apply(firstVector));

    // Nearest training points to (-2, -2) are (-1, -2) and (-2, -1), both labeled 2.0.
    Vector secondVector = new DenseLocalOnHeapVector(new double[] { -2.0, -2.0 });
    assertEquals(2.0, knnMdl.apply(secondVector));
}
Also used : EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) KNNModel(org.apache.ignite.ml.knn.models.KNNModel) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) LabeledDataset(org.apache.ignite.ml.structures.LabeledDataset) Vector(org.apache.ignite.ml.math.Vector)

Example 4 with KNNModel

Use of org.apache.ignite.ml.knn.models.KNNModel in the Apache Ignite project.

In class KNNClassificationTest, method testBinaryClassificationFarPointsWithSimpleStrategy.

/**
 * Checks binary classification with k = 3 and {@code KNNStrategy.SIMPLE} when two training points of class 1.0
 * are placed far away from the query point.
 */
public void testBinaryClassificationFarPointsWithSimpleStrategy() {
    IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

    double[][] mtx = new double[][] { { 10.0, 10.0 }, { 10.0, 20.0 }, { -1, -1 }, { -2, -2 }, { -1.0, -2.0 }, { -2.0, -1.0 } };
    double[] lbs = new double[] { 1.0, 1.0, 1.0, 2.0, 2.0, 2.0 };
    LabeledDataset training = new LabeledDataset(mtx, lbs);

    KNNModel knnMdl = new KNNModel(3, new EuclideanDistance(), KNNStrategy.SIMPLE, training);

    // The three nearest training points to (-1.01, -1.01) are (-1, -1) [1.0], (-1, -2) [2.0] and (-2, -1) [2.0].
    // Label 2.0 is expected although the single closest point is labeled 1.0
    // (presumably SIMPLE means an unweighted vote — confirm).
    Vector vector = new DenseLocalOnHeapVector(new double[] { -1.01, -1.01 });

    // Expected value first, so failure messages report expected/actual correctly.
    assertEquals(2.0, knnMdl.apply(vector));
}
Also used : EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) KNNModel(org.apache.ignite.ml.knn.models.KNNModel) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) LabeledDataset(org.apache.ignite.ml.structures.LabeledDataset) Vector(org.apache.ignite.ml.math.Vector)

Example 5 with KNNModel

Use of org.apache.ignite.ml.knn.models.KNNModel in the Apache Ignite project.

In class KNNClassificationTest, method testLargeKValue.

/**
 * Checks that creating a {@link KNNModel} whose k (7) exceeds the number of training rows (6) fails with
 * {@link SmallTrainingDatasetSizeException}.
 */
public void testLargeKValue() {
    IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

    double[][] mtx = new double[][] { { 10.0, 10.0 }, { 10.0, 20.0 }, { -1, -1 }, { -2, -2 }, { -1.0, -2.0 }, { -2.0, -1.0 } };
    double[] lbs = new double[] { 1.0, 1.0, 1.0, 2.0, 2.0, 2.0 };
    LabeledDataset training = new LabeledDataset(mtx, lbs);

    try {
        new KNNModel(7, new EuclideanDistance(), KNNStrategy.SIMPLE, training);

        // Reached only if the constructor accepted k > training size.
        fail("SmallTrainingDatasetSizeException");
    } catch (SmallTrainingDatasetSizeException ignored) {
        // Expected: k = 7 is larger than the 6 training rows.
    }
}
Also used : EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) SmallTrainingDatasetSizeException(org.apache.ignite.ml.math.exceptions.knn.SmallTrainingDatasetSizeException) KNNModel(org.apache.ignite.ml.knn.models.KNNModel) LabeledDataset(org.apache.ignite.ml.structures.LabeledDataset)

Aggregations

KNNModel (org.apache.ignite.ml.knn.models.KNNModel): 10; LabeledDataset (org.apache.ignite.ml.structures.LabeledDataset): 10; EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance): 9; Vector (org.apache.ignite.ml.math.Vector): 5; DenseLocalOnHeapVector (org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector): 5; LabeledDatasetTestTrainPair (org.apache.ignite.ml.structures.LabeledDatasetTestTrainPair): 3; IgniteThread (org.apache.ignite.thread.IgniteThread): 3; File (java.io.File): 1; IOException (java.io.IOException): 1; Path (java.nio.file.Path): 1; Ignite (org.apache.ignite.Ignite): 1; KNNModelFormat (org.apache.ignite.ml.knn.models.KNNModelFormat): 1; KNNMultipleLinearRegression (org.apache.ignite.ml.knn.regression.KNNMultipleLinearRegression): 1; ManhattanDistance (org.apache.ignite.ml.math.distances.ManhattanDistance): 1; SmallTrainingDatasetSizeException (org.apache.ignite.ml.math.exceptions.knn.SmallTrainingDatasetSizeException): 1; Test (org.junit.Test): 1