Search in sources :

Example 46 with DenseLocalOnHeapVector

use of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in project ignite by apache.

the class DistributedLinearRegressionWithLSQRTrainerExample method main.

/**
 * Runs the example: starts an Ignite node, fits a linear regression model with the LSQR trainer
 * on the cache obtained from {@code getTestCache}, and prints the model's prediction next to the
 * ground truth for every cache entry.
 * <p>
 * Each cached {@code double[]} is read as: element 0 is the label, the remaining elements are the
 * feature values.
 *
 * @param args Command line arguments (not read by this method).
 * @throws InterruptedException If the calling thread is interrupted while waiting for the worker thread.
 */
public static void main(String[] args) throws InterruptedException {
    System.out.println();
    System.out.println(">>> Linear regression model over sparse distributed matrix API usage example started.");

    // Start ignite grid.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        String gridName = ignite.configuration().getIgniteInstanceName();
        String threadName = SparseDistributedMatrixExample.class.getSimpleName();

        // The whole workload is executed inside an IgniteThread; per the original author's note this
        // is required because an Ignite cache is created internally.
        IgniteThread worker = new IgniteThread(gridName, threadName, () -> {
            IgniteCache<Integer, double[]> cache = getTestCache(ignite);

            System.out.println(">>> Create new linear regression trainer object.");
            LinearRegressionLSQRTrainer<Integer, double[]> lsqrTrainer = new LinearRegressionLSQRTrainer<>();

            System.out.println(">>> Perform the training to get the model.");
            // Feature extractor drops element 0; label extractor returns element 0.
            // NOTE(review): the trailing 4 is presumably the number of feature columns - confirm against the trainer API.
            LinearRegressionModel model = lsqrTrainer.fit(
                new CacheBasedDatasetBuilder<>(ignite, cache),
                (k, v) -> Arrays.copyOfRange(v, 1, v.length),
                (k, v) -> v[0],
                4);

            System.out.println(">>> Linear regression model: " + model);

            System.out.println(">>> ---------------------------------");
            System.out.println(">>> | Prediction\t| Ground Truth\t|");
            System.out.println(">>> ---------------------------------");

            // Scan every entry of the cache and compare the model output with the stored label.
            try (QueryCursor<Cache.Entry<Integer, double[]>> cursor = cache.query(new ScanQuery<>())) {
                for (Cache.Entry<Integer, double[]> entry : cursor) {
                    double[] row = entry.getValue();
                    double expected = row[0];
                    double[] features = Arrays.copyOfRange(row, 1, row.length);
                    double predicted = model.apply(new DenseLocalOnHeapVector(features));

                    System.out.printf(">>> | %.4f\t\t| %.4f\t\t|\n", predicted, expected);
                }
            }

            System.out.println(">>> ---------------------------------");
        });

        worker.start();
        worker.join();
    }
}
Also used : LinearRegressionModel(org.apache.ignite.ml.regressions.linear.LinearRegressionModel) LinearRegressionLSQRTrainer(org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer) SparseDistributedMatrixExample(org.apache.ignite.examples.ml.math.matrix.SparseDistributedMatrixExample) Ignite(org.apache.ignite.Ignite) IgniteThread(org.apache.ignite.thread.IgniteThread) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) IgniteCache(org.apache.ignite.IgniteCache) Cache(javax.cache.Cache)

Example 47 with DenseLocalOnHeapVector

use of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in project ignite by apache.

the class DecisionTreesExample method main.

/**
 * Launches the decision trees example.
 * <p>
 * Steps, in order: parse the command line (returning early in 'unattended' mode or on a parse
 * error), make sure the MNIST files are available via {@code getMNIST}, resolve their paths,
 * start Ignite, load the training vectors into a bi-indexed cache, train a column-based decision
 * tree and print the value computed by {@code Estimators.errorsPercentage()} on the test stream.
 *
 * @param args Program arguments.
 * @throws IOException Declared by the signature; an {@code IOException} raised inside the Ignite
 *      section is caught and its stack trace printed instead.
 */
public static void main(String[] args) throws IOException {
    System.out.println(">>> Decision trees example started.");

    // File names of the four MNIST archives, keyed by their role.
    Map<String, String> mnistFiles = new HashMap<>();

    mnistFiles.put(MNIST_TRAIN_IMAGES, "train-images-idx3-ubyte");
    mnistFiles.put(MNIST_TRAIN_LABELS, "train-labels-idx1-ubyte");
    mnistFiles.put(MNIST_TEST_IMAGES, "t10k-images-idx3-ubyte");
    mnistFiles.put(MNIST_TEST_LABELS, "t10k-labels-idx1-ubyte");

    String cfgPath;
    CommandLineParser cliParser = new BasicParser();

    try {
        // Parse the command line arguments.
        CommandLine cli = cliParser.parse(buildOptions(), args);

        if (cli.hasOption(MLExamplesCommonArgs.UNATTENDED)) {
            System.out.println(">>> Skipped example execution because 'unattended' mode is used.");
            System.out.println(">>> Decision trees example finished.");

            return;
        }

        cfgPath = cli.getOptionValue(CONFIG, DEFAULT_CONFIG);
    } catch (ParseException e) {
        e.printStackTrace();

        return;
    }

    if (!getMNIST(mnistFiles.values())) {
        System.out.println(">>> You should have MNIST dataset in " + MNIST_DIR + " to run this example.");

        return;
    }

    // Resolve each file; a path that cannot be resolved fails fast through requireNonNull.
    String trainImgs = Objects.requireNonNull(
        IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + mnistFiles.get(MNIST_TRAIN_IMAGES))).getPath();
    String trainLbls = Objects.requireNonNull(
        IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + mnistFiles.get(MNIST_TRAIN_LABELS))).getPath();
    String testImgs = Objects.requireNonNull(
        IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + mnistFiles.get(MNIST_TEST_IMAGES))).getPath();
    String testLbls = Objects.requireNonNull(
        IgniteUtils.resolveIgnitePath(MNIST_DIR + "/" + mnistFiles.get(MNIST_TEST_LABELS))).getPath();

    try (Ignite ignite = Ignition.start(cfgPath)) {
        IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

        int trainSize = 60000;
        int featureCnt = 28 * 28;

        Stream<DenseLocalOnHeapVector> trainVectors =
            MnistUtils.mnist(trainImgs, trainLbls, new Random(123L), trainSize);
        Stream<DenseLocalOnHeapVector> testVectors =
            MnistUtils.mnist(testImgs, testLbls, new Random(123L), 10_000);

        // Each vector is stored with one extra element past the features (read below as the label).
        IgniteCache<BiIndex, Double> biIdxCache = createBiIndexedCache(ignite);
        loadVectorsIntoBiIndexedCache(biIdxCache.getName(), trainVectors.iterator(), featureCnt + 1, ignite);

        ColumnDecisionTreeTrainer<GiniSplitCalculator.GiniData> treeTrainer = new ColumnDecisionTreeTrainer<>(
            10,
            ContinuousSplitCalculators.GINI.apply(ignite),
            RegionCalculators.GINI,
            RegionCalculators.MOST_COMMON,
            ignite);

        System.out.println(">>> Training started");

        long startedAt = System.currentTimeMillis();
        DecisionTreeModel model = treeTrainer.train(
            new BiIndexedCacheColumnDecisionTreeTrainerInput(biIdxCache, new HashMap<>(), trainSize, featureCnt));

        System.out.println(">>> Training finished in " + (System.currentTimeMillis() - startedAt));

        IgniteTriFunction<Model<Vector, Double>, Stream<IgniteBiTuple<Vector, Double>>, Function<Double, Double>, Double> errEstimator =
            Estimators.errorsPercentage();

        // Split every test vector into (features, label): the first featureCnt elements and the element after them.
        Stream<IgniteBiTuple<Vector, Double>> labeledTest =
            testVectors.map(v -> new IgniteBiTuple<>(v.viewPart(0, featureCnt), v.getX(featureCnt)));

        Double errsPercentage = errEstimator.apply(model, labeledTest, Function.identity());

        System.out.println(">>> Errs percentage: " + errsPercentage);
    } catch (IOException e) {
        e.printStackTrace();
    }

    System.out.println(">>> Decision trees example finished.");
}
Also used : GZIPInputStream(java.util.zip.GZIPInputStream) URL(java.net.URL) Scanner(java.util.Scanner) Random(java.util.Random) BiIndex(org.apache.ignite.ml.trees.trainers.columnbased.BiIndex) ExampleNodeStartup(org.apache.ignite.examples.ExampleNodeStartup) Vector(org.apache.ignite.ml.math.Vector) Estimators(org.apache.ignite.ml.estimators.Estimators) Map(java.util.Map) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) Collection(java.util.Collection) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) Collectors(java.util.stream.Collectors) IgniteCache(org.apache.ignite.IgniteCache) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) Objects(java.util.Objects) List(java.util.List) Stream(java.util.stream.Stream) ParseException(org.apache.commons.cli.ParseException) RegionCalculators(org.apache.ignite.ml.trees.trainers.columnbased.regcalcs.RegionCalculators) NotNull(org.jetbrains.annotations.NotNull) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) Model(org.apache.ignite.ml.Model) Options(org.apache.commons.cli.Options) HashMap(java.util.HashMap) Function(java.util.function.Function) GiniSplitCalculator(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.GiniSplitCalculator) BiIndexedCacheColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput) OptionBuilder(org.apache.commons.cli.OptionBuilder) CacheWriteSynchronizationMode(org.apache.ignite.cache.CacheWriteSynchronizationMode) IgniteUtils(org.apache.ignite.internal.util.IgniteUtils) BasicParser(org.apache.commons.cli.BasicParser) CommandLine(org.apache.commons.cli.CommandLine) MnistUtils(org.apache.ignite.ml.util.MnistUtils) Option(org.apache.commons.cli.Option) ReadableByteChannel(java.nio.channels.ReadableByteChannel) Iterator(java.util.Iterator) 
ContinuousSplitCalculators(org.apache.ignite.ml.trees.trainers.columnbased.contsplitcalcs.ContinuousSplitCalculators) CommandLineParser(org.apache.commons.cli.CommandLineParser) Channels(java.nio.channels.Channels) FileOutputStream(java.io.FileOutputStream) IOException(java.io.IOException) FileInputStream(java.io.FileInputStream) Ignite(org.apache.ignite.Ignite) MLExamplesCommonArgs(org.apache.ignite.examples.ml.MLExamplesCommonArgs) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) IgniteDataStreamer(org.apache.ignite.IgniteDataStreamer) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) HashMap(java.util.HashMap) IgniteBiTuple(org.apache.ignite.lang.IgniteBiTuple) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) Function(java.util.function.Function) Random(java.util.Random) Ignite(org.apache.ignite.Ignite) GZIPInputStream(java.util.zip.GZIPInputStream) Stream(java.util.stream.Stream) FileOutputStream(java.io.FileOutputStream) FileInputStream(java.io.FileInputStream) CommandLineParser(org.apache.commons.cli.CommandLineParser) Vector(org.apache.ignite.ml.math.Vector) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) ColumnDecisionTreeTrainer(org.apache.ignite.ml.trees.trainers.columnbased.ColumnDecisionTreeTrainer) BiIndex(org.apache.ignite.ml.trees.trainers.columnbased.BiIndex) IOException(java.io.IOException) BiIndexedCacheColumnDecisionTreeTrainerInput(org.apache.ignite.ml.trees.trainers.columnbased.BiIndexedCacheColumnDecisionTreeTrainerInput) BasicParser(org.apache.commons.cli.BasicParser) CommandLine(org.apache.commons.cli.CommandLine) DecisionTreeModel(org.apache.ignite.ml.trees.models.DecisionTreeModel) Model(org.apache.ignite.ml.Model) ParseException(org.apache.commons.cli.ParseException) 
DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector)

Example 48 with DenseLocalOnHeapVector

use of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in project ignite by apache.

the class AbstractMultipleLinearRegressionTest method testNewSample.

/**
 * Verifies that newSampleData methods consistently insert unitary columns
 * in design matrix. Confirms the fix for MATH-411.
 * <p>
 * The flat array interleaves each y value with the matching row of x, so loading it through
 * {@code newSampleData} must yield the same X and Y as loading x and y separately. The check is
 * done twice: with the regression as created and again after {@code setNoIntercept(true)}.
 */
@Test
public void testNewSample() {
    double[] flatDesign = {1, 19, 22, 33, 2, 20, 30, 40, 3, 25, 35, 45, 4, 27, 37, 47};
    double[] responses = {1, 2, 3, 4};
    double[][] regressors = {{19, 22, 33}, {20, 30, 40}, {25, 35, 45}, {27, 37, 47}};

    AbstractMultipleLinearRegression regression = createRegression();

    // Pass 0 leaves the intercept setting untouched; pass 1 switches to the no-intercept mode.
    for (int pass = 0; pass < 2; pass++) {
        if (pass == 1) {
            // No intercept
            regression.setNoIntercept(true);
        }

        // Reference state: everything loaded from the single flat array.
        regression.newSampleData(flatDesign, 4, 3, new DenseLocalOnHeapMatrix());

        Matrix expX = regression.getX().copy();
        Vector expY = regression.getY().copy();

        // Same data supplied as separate X and Y.
        regression.newXSampleData(new DenseLocalOnHeapMatrix(regressors));
        regression.newYSampleData(new DenseLocalOnHeapVector(responses));

        Assert.assertEquals(expX, regression.getX());
        Assert.assertEquals(expY, regression.getY());
    }
}
Also used : Matrix(org.apache.ignite.ml.math.Matrix) DenseLocalOnHeapMatrix(org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) DenseLocalOnHeapMatrix(org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix) Vector(org.apache.ignite.ml.math.Vector) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) Test(org.junit.Test)

Example 49 with DenseLocalOnHeapVector

use of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in project ignite by apache.

the class OLSMultipleLinearRegressionTest method testMathIllegalArgumentException.

/**
 * Expects {@code validateSampleData} to throw {@link MathIllegalArgumentException} when given a
 * matrix constructed with arguments {@code (1, 2)} and a vector constructed with argument {@code 1}.
 */
@Test(expected = MathIllegalArgumentException.class)
public void testMathIllegalArgumentException() {
    OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();

    DenseLocalOnHeapMatrix x = new DenseLocalOnHeapMatrix(1, 2);
    DenseLocalOnHeapVector y = new DenseLocalOnHeapVector(1);

    regression.validateSampleData(x, y);
}
Also used : DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) DenseLocalOnHeapMatrix(org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix) Test(org.junit.Test)

Example 50 with DenseLocalOnHeapVector

use of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in project ignite by apache.

the class OLSMultipleLinearRegressionTest method testNewSample2.

/**
 * Verifies that setting X and Y separately has the same effect as newSample(X,Y).
 * <p>
 * The comparison is performed twice: with a freshly constructed regression and again after
 * {@code setNoIntercept(true)}.
 */
@Test
public void testNewSample2() {
    double[] responses = {1, 2, 3, 4};
    double[][] regressors = {{19, 22, 33}, {20, 30, 40}, {25, 35, 45}, {27, 37, 47}};

    OLSMultipleLinearRegression regression = new OLSMultipleLinearRegression();

    // Pass 0 leaves the intercept setting untouched; pass 1 switches to the no-intercept mode.
    for (int pass = 0; pass < 2; pass++) {
        if (pass == 1) {
            // No intercept
            regression.setNoIntercept(true);
        }

        // Reference state: Y and X supplied together.
        regression.newSampleData(new DenseLocalOnHeapVector(responses), new DenseLocalOnHeapMatrix(regressors));

        Matrix expX = regression.getX().copy();
        Vector expY = regression.getY().copy();

        // Same data supplied through the separate setters.
        regression.newXSampleData(new DenseLocalOnHeapMatrix(regressors));
        regression.newYSampleData(new DenseLocalOnHeapVector(responses));

        Assert.assertEquals(expX, regression.getX());
        Assert.assertEquals(expY, regression.getY());
    }
}
Also used : Matrix(org.apache.ignite.ml.math.Matrix) DenseLocalOnHeapMatrix(org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) DenseLocalOnHeapMatrix(org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix) Vector(org.apache.ignite.ml.math.Vector) DenseLocalOnHeapVector(org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector) Test(org.junit.Test)

Aggregations

DenseLocalOnHeapVector (org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector): 98, Vector (org.apache.ignite.ml.math.Vector): 49, Test (org.junit.Test): 44, DenseLocalOnHeapMatrix (org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix): 26, Random (java.util.Random): 18, HashMap (java.util.HashMap): 17, EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance): 14, Matrix (org.apache.ignite.ml.math.Matrix): 12, SparseDistributedMatrix (org.apache.ignite.ml.math.impls.matrix.SparseDistributedMatrix): 11, IgniteCache (org.apache.ignite.IgniteCache): 8, LabeledDataset (org.apache.ignite.ml.structures.LabeledDataset): 8, Arrays (java.util.Arrays): 7, Collections (java.util.Collections): 6, List (java.util.List): 6, Map (java.util.Map): 6, Collectors (java.util.stream.Collectors): 6, Stream (java.util.stream.Stream): 6, Ignite (org.apache.ignite.Ignite): 6, IgniteUtils (org.apache.ignite.internal.util.IgniteUtils): 6, IgniteBiTuple (org.apache.ignite.lang.IgniteBiTuple): 6