Search in sources:

Example 1 with DenseVector

use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in project ignite by apache.

In class AlgorithmSpecificDatasetExample, method createCache.

/**
 * Creates the {@code PERSONS} cache (rendezvous affinity, 2 partitions) and fills it with four sample rows.
 * <p>
 * Each row is a {@link DenseVector} of a name followed by two integer attributes; rows are stored under keys 1..4.
 *
 * @param ignite Ignite instance used to create the cache.
 * @return The populated cache.
 */
private static IgniteCache<Integer, Vector> createCache(Ignite ignite) {
    CacheConfiguration<Integer, Vector> cfg = new CacheConfiguration<>();
    cfg.setName("PERSONS");
    cfg.setAffinity(new RendezvousAffinityFunction(false, 2));

    IgniteCache<Integer, Vector> cache = ignite.createCache(cfg);

    Serializable[][] rows = {
        {"Mike", 42, 10000},
        {"John", 32, 64000},
        {"George", 53, 120000},
        {"Karl", 24, 70000}
    };

    // Keys are 1-based, matching row position.
    for (int idx = 0; idx < rows.length; idx++)
        cache.put(idx + 1, new DenseVector(rows[idx]));

    return cache;
}
Also used: Serializable(java.io.Serializable) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration)

Example 2 with DenseVector

use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in project ignite by apache.

In class ANNClassificationExportImportExample, method evaluateModel.

/**
 * Scans every entry of {@code dataCache}, predicts its label with {@code knnMdl}, prints a prediction / ground-truth
 * table and returns the fraction of correct predictions.
 * <p>
 * Each cached value is treated as {@code [label, feature_1, ..., feature_n]}: element 0 is the ground truth and the
 * remaining elements form the input vector. A prediction counts as correct when it equals the label within
 * {@code Precision.EPSILON}.
 *
 * @param dataCache Cache holding the labeled observations.
 * @param knnMdl Trained model to evaluate.
 * @return Accuracy in {@code [0, 1]}; {@code NaN} when the cache is empty (0 / 0).
 */
private static double evaluateModel(IgniteCache<Integer, double[]> dataCache, NNClassificationModel knnMdl) {
    int errors = 0;
    int total = 0;
    double result;

    try (QueryCursor<Cache.Entry<Integer, double[]>> cursor = dataCache.query(new ScanQuery<>())) {
        System.out.println(">>> ---------------------------------");
        System.out.println(">>> | Prediction\t| Ground Truth\t|");
        System.out.println(">>> ---------------------------------");

        for (Cache.Entry<Integer, double[]> entry : cursor) {
            double[] row = entry.getValue();
            // Features are copied before the label is read, so malformed rows fail the same way as before.
            double[] features = Arrays.copyOfRange(row, 1, row.length);
            double label = row[0];

            double predicted = knnMdl.predict(new DenseVector(features));
            total++;

            boolean hit = Precision.equals(label, predicted, Precision.EPSILON);
            if (!hit)
                errors++;

            System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, label);
        }

        System.out.println(">>> ---------------------------------");
        result = 1 - errors / (double) total;
    }

    return result;
}
Also used : DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) IgniteCache(org.apache.ignite.IgniteCache) Cache(javax.cache.Cache)

Example 3 with DenseVector

use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in project ignite by apache.

In class ANNClassificationExample, method main.

/**
 * Runs the ANN multi-class classification example.
 * <p>
 * Starts an Ignite node from {@code examples/config/example-ignite.xml} and fills a test cache via
 * {@code getTestCache}. It then trains an {@link ANNClassificationTrainer} and turns the result into a weighted
 * 5-NN model with Euclidean distance. Finally it scans the cache, prints each prediction next to its ground truth,
 * and reports the error count, accuracy and training/prediction times in milliseconds. The test cache is destroyed
 * afterwards.
 *
 * @param args Command line arguments (unused).
 */
public static void main(String[] args) {
    System.out.println();
    System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example started.");
    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        IgniteCache<Integer, double[]> dataCache = null;
        try {
            dataCache = getTestCache(ignite);

            ANNClassificationTrainer trainer = new ANNClassificationTrainer()
                .withDistance(new ManhattanDistance())
                .withK(50)
                .withMaxIterations(1000)
                .withEpsilon(1e-2);

            // nanoTime is monotonic and fine-grained. Summing currentTimeMillis deltas of sub-millisecond
            // predictions would round almost every sample down to 0.
            long startTrainingTime = System.nanoTime();
            NNClassificationModel knnMdl = trainer.fit(ignite, dataCache,
                new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.FIRST))
                .withK(5)
                .withDistanceMeasure(new EuclideanDistance())
                .withWeighted(true);
            long trainingTimeMs = (System.nanoTime() - startTrainingTime) / 1_000_000;

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");

            int amountOfErrors = 0;
            int totalAmount = 0;
            long totalPredictionTimeNs = 0L;

            try (QueryCursor<Cache.Entry<Integer, double[]>> observations = dataCache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, double[]> observation : observations) {
                    double[] val = observation.getValue();
                    // Element 0 is the label; the rest are the features.
                    double[] inputs = Arrays.copyOfRange(val, 1, val.length);
                    double groundTruth = val[0];

                    long startPredictionTime = System.nanoTime();
                    double prediction = knnMdl.predict(new DenseVector(inputs));
                    totalPredictionTimeNs += System.nanoTime() - startPredictionTime;

                    totalAmount++;
                    if (!Precision.equals(groundTruth, prediction, Precision.EPSILON))
                        amountOfErrors++;

                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", prediction, groundTruth);
                }

                System.out.println(">>> ---------------------------------");
                System.out.println("Training costs = " + trainingTimeMs);
                System.out.println("Prediction costs = " + totalPredictionTimeNs / 1_000_000);
                System.out.println("\n>>> Absolute amount of errors " + amountOfErrors);
                System.out.println("\n>>> Accuracy " + (1 - amountOfErrors / (double) totalAmount));
                System.out.println("\n>>> Total amount of observations " + totalAmount);
                System.out.println(">>> ANN multi-class classification algorithm over cached dataset usage example completed.");
            }
        }
        finally {
            // getTestCache may have thrown before assignment; destroying null would mask the original exception.
            if (dataCache != null)
                dataCache.destroy();
        }
    }
    finally {
        System.out.flush();
    }
}
Also used : NNClassificationModel(org.apache.ignite.ml.knn.NNClassificationModel) EuclideanDistance(org.apache.ignite.ml.math.distances.EuclideanDistance) Ignite(org.apache.ignite.Ignite) ANNClassificationTrainer(org.apache.ignite.ml.knn.ann.ANNClassificationTrainer) ManhattanDistance(org.apache.ignite.ml.math.distances.ManhattanDistance) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) IgniteCache(org.apache.ignite.IgniteCache) Cache(javax.cache.Cache)

Example 4 with DenseVector

use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in project ignite by apache.

In class TitanicUtils, method readPassengersWithoutNulls.

/**
 * Reads the Titanic passengers data set (without nulls) from a semicolon-separated CSV file into a cache.
 * <p>
 * The first line is treated as a header and skipped. Every following line becomes a {@link DenseVector} stored
 * under keys starting at 1. Each cell is converted as follows:
 * <ul>
 *     <li>an empty cell becomes {@code Double.NaN};</li>
 *     <li>a cell parseable by {@link Double#valueOf(String)} becomes that double;</li>
 *     <li>otherwise a cell parseable with French number formatting (e.g. a decimal comma) becomes that double;</li>
 *     <li>any other cell is kept as its original string.</li>
 * </ul>
 * Note that {@code String.split} drops trailing empty cells, so rows ending in {@code ;} yield shorter vectors.
 *
 * @param ignite The ignite.
 * @return The filled cache.
 * @throws FileNotFoundException If data file is not found.
 */
public static IgniteCache<Integer, Vector> readPassengersWithoutNulls(Ignite ignite) throws FileNotFoundException {
    IgniteCache<Integer, Vector> cache = getCache(ignite);

    String dataPath = "examples/src/main/resources/datasets/titanic_without_nulls.csv";
    java.io.File dataFile = IgniteUtils.resolveIgnitePath(dataPath);
    // resolveIgnitePath may return null; report that via the declared exception instead of an NPE.
    if (dataFile == null)
        throw new FileNotFoundException(dataPath);

    // Created once rather than per row; only used sequentially on this thread.
    NumberFormat format = NumberFormat.getInstance(Locale.FRANCE);

    // try-with-resources closes the underlying file handle even if parsing or cache.put fails.
    try (Scanner scanner = new Scanner(dataFile)) {
        // Skip the header line.
        if (scanner.hasNextLine())
            scanner.nextLine();

        int key = 1;
        while (scanner.hasNextLine()) {
            String[] cells = scanner.nextLine().split(";");
            Serializable[] data = new Serializable[cells.length];

            for (int i = 0; i < cells.length; i++) {
                try {
                    data[i] = "".equals(cells[i]) ? Double.NaN : Double.valueOf(cells[i]);
                }
                catch (java.lang.NumberFormatException e) {
                    try {
                        data[i] = format.parse(cells[i]).doubleValue();
                    }
                    catch (ParseException e1) {
                        // Not numeric in either format: keep the raw string (e.g. names, categories).
                        data[i] = cells[i];
                    }
                }
            }

            cache.put(key++, new DenseVector(data));
        }
    }

    return cache;
}
Also used: Scanner(java.util.Scanner) Serializable(java.io.Serializable) ParseException(java.text.ParseException) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) NumberFormat(java.text.NumberFormat)

Example 5 with DenseVector

use of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in project ignite by apache.

In class ImputingExample, method createCache.

/**
 * Creates the {@code PERSONS} cache (rendezvous affinity, 2 partitions) and fills it with seven sample rows.
 * <p>
 * Each row is a {@link DenseVector} of a name followed by two numeric attributes. Some attributes are
 * {@code Double.NaN} so the imputing example has missing values to fill in. Rows are stored under distinct keys 1..7.
 *
 * @param ignite Ignite instance used to create the cache.
 * @return The populated cache.
 */
private static IgniteCache<Integer, Vector> createCache(Ignite ignite) {
    CacheConfiguration<Integer, Vector> cacheConfiguration = new CacheConfiguration<>();
    cacheConfiguration.setName("PERSONS");
    cacheConfiguration.setAffinity(new RendezvousAffinityFunction(false, 2));

    IgniteCache<Integer, Vector> persons = ignite.createCache(cacheConfiguration);

    // Distinct keys: previously every row was put under key 1, so each put overwrote the previous one
    // and only the last row ("Alex") survived.
    persons.put(1, new DenseVector(new Serializable[] {"Mike", 10, 1}));
    persons.put(2, new DenseVector(new Serializable[] {"John", 20, 2}));
    persons.put(3, new DenseVector(new Serializable[] {"George", 15, 1}));
    persons.put(4, new DenseVector(new Serializable[] {"Piter", 25, Double.NaN}));
    persons.put(5, new DenseVector(new Serializable[] {"Karl", Double.NaN, 1}));
    persons.put(6, new DenseVector(new Serializable[] {"Gustaw", 20, 2}));
    persons.put(7, new DenseVector(new Serializable[] {"Alex", 20, 2}));

    return persons;
}
Also used: Serializable(java.io.Serializable) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration)

Aggregations

DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)101 Vector (org.apache.ignite.ml.math.primitives.vector.Vector)59 Test (org.junit.Test)59 Serializable (java.io.Serializable)16 SparseVector (org.apache.ignite.ml.math.primitives.vector.impl.SparseVector)14 HashMap (java.util.HashMap)13 DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix)13 DummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.DummyVectorizer)10 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)10 RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)9 CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)9 HashSet (java.util.HashSet)7 TrainerTest (org.apache.ignite.ml.common.TrainerTest)7 KMeansModel (org.apache.ignite.ml.clustering.kmeans.KMeansModel)5 LocalDatasetBuilder (org.apache.ignite.ml.dataset.impl.local.LocalDatasetBuilder)5 EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance)5 IgniteDifferentiableVectorToDoubleFunction (org.apache.ignite.ml.math.functions.IgniteDifferentiableVectorToDoubleFunction)5 MLPArchitecture (org.apache.ignite.ml.nn.architecture.MLPArchitecture)5 OneHotEncoderPreprocessor (org.apache.ignite.ml.preprocessing.encoding.onehotencoder.OneHotEncoderPreprocessor)4 Random (java.util.Random)3