Search in sources:

Example 1 with LabeledDummyVectorizer

use of org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer in project ignite by apache.

From the class DecisionTreeClassificationExportImportExample, method main.

/**
 * Demonstrates exporting a trained decision tree classifier to JSON and importing it back.
 * A 1000-entry training cache is filled with generated points, a model is trained on it and
 * evaluated, the model is written to a temporary JSON file, read back from that file and
 * evaluated again. The cache and the temporary file are removed afterwards.
 *
 * @param args Command line arguments; not used.
 * @throws IOException If an I/O operation on the temporary JSON file fails.
 */
public static void main(String[] args) throws IOException {
    System.out.println(">>> Decision tree classification trainer example started.");

    // The node is closed automatically when this block exits.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println("\n>>> Ignite grid started.");

        CacheConfiguration<Integer, LabeledVector<Double>> cacheCfg = new CacheConfiguration<>();
        cacheCfg.setName("TRAINING_SET");
        cacheCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

        // Released in the finally block below, but only if they were actually created.
        IgniteCache<Integer, LabeledVector<Double>> dataCache = null;
        Path jsonFile = null;

        try {
            dataCache = ignite.createCache(cacheCfg);

            Random random = new Random(0);
            for (int key = 0; key < 1000; key++)
                dataCache.put(key, generatePoint(random));

            DecisionTreeClassificationTrainer treeTrainer = new DecisionTreeClassificationTrainer(4, 0);
            LabeledDummyVectorizer<Integer, Double> dummyVectorizer = new LabeledDummyVectorizer<>();

            DecisionTreeModel exported = treeTrainer.fit(ignite, dataCache, dummyVectorizer);
            System.out.println("\n>>> Exported Decision tree classification model: " + exported);

            int exportedHits = evaluateModel(random, exported);
            System.out.println("\n>>> Accuracy for exported Decision tree classification model: " + exportedHits / 10.0 + "%");

            // Round-trip the model through a temporary JSON file.
            jsonFile = Files.createTempFile(null, null);
            exported.toJSON(jsonFile);

            DecisionTreeModel imported = DecisionTreeModel.fromJSON(jsonFile);
            System.out.println("\n>>> Imported Decision tree classification model: " + imported);

            // NOTE(review): the same Random instance keeps advancing, so the imported model is
            // presumably scored on different points than the exported one — confirm whether an
            // exact accuracy match between the two printouts is expected.
            int importedHits = evaluateModel(random, imported);
            System.out.println("\n>>> Accuracy for imported Decision tree classification model: " + importedHits / 10.0 + "%");

            System.out.println("\n>>> Decision tree classification trainer example completed.");
        }
        finally {
            if (dataCache != null)
                dataCache.destroy();

            if (jsonFile != null)
                Files.deleteIfExists(jsonFile);
        }
    }
    finally {
        System.out.flush();
    }
}
Also used : Path(java.nio.file.Path) DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) DecisionTreeClassificationTrainer(org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer) Random(java.util.Random) Ignite(org.apache.ignite.Ignite) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration)

Example 2 with LabeledDummyVectorizer

use of org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer in project ignite by apache.

From the class CrossValidationExample, method main.

/**
 * Runs the cross-validation example. Fills a cache with 1000 generated labeled points, then
 * scores a {@code DecisionTreeClassificationTrainer(4, 0)} with 4-fold cross validation and
 * prints the per-fold scores.
 *
 * @param args Command line arguments; none are used.
 */
public static void main(String... args) {
    System.out.println(">>> Cross validation score calculator example started.");

    // Start ignite grid; the node is closed when the try block exits.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Create cache with training data.
        CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
        trainingSetCfg.setName("TRAINING_SET");
        trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

        IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;

        try {
            trainingSet = ignite.createCache(trainingSetCfg);

            Random rnd = new Random(0);

            // Fill training data.
            for (int i = 0; i < 1000; i++)
                trainingSet.put(i, generatePoint(rnd));

            // Create classification trainer.
            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

            LabeledDummyVectorizer<Integer, Double> vectorizer = new LabeledDummyVectorizer<>();

            CrossValidation<DecisionTreeModel, Integer, LabeledVector<Double>> scoreCalculator = new CrossValidation<>();

            double[] accuracyScores = scoreCalculator
                .withIgnite(ignite)
                .withUpstreamCache(trainingSet)
                .withTrainer(trainer)
                .withMetric(MetricName.ACCURACY)
                .withPreprocessor(vectorizer)
                .withAmountOfFolds(4)
                .isRunningOnPipeline(false)
                .scoreByFolds();

            System.out.println(">>> Accuracy: " + Arrays.toString(accuracyScores));

            // NOTE(review): this second run reuses the calculator configured above and requests
            // MetricName.ACCURACY again, so the values printed as "Balanced Accuracy" are plain
            // accuracy. Switch to a balanced-accuracy metric if the MetricName version in use
            // provides one, or relabel the output — TODO confirm which was intended.
            double[] balancedAccuracyScores = scoreCalculator
                .withMetric(MetricName.ACCURACY)
                .scoreByFolds();

            System.out.println(">>> Balanced Accuracy: " + Arrays.toString(balancedAccuracyScores));

            System.out.println(">>> Cross validation score calculator example completed.");
        }
        finally {
            // createCache may throw before trainingSet is assigned; guard so the original
            // exception is not replaced by a NullPointerException from this finally block.
            if (trainingSet != null)
                trainingSet.destroy();
        }
    }
    finally {
        System.out.flush();
    }
}
Also used : DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) DecisionTreeClassificationTrainer(org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer) Random(java.util.Random) Ignite(org.apache.ignite.Ignite) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) CrossValidation(org.apache.ignite.ml.selection.cv.CrossValidation) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration)

Example 3 with LabeledDummyVectorizer

use of org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer in project ignite by apache.

From the class RandomForestClassifierTrainerTest, method testUpdate.

/**
 * Trains a random forest classifier on a synthetic 4-feature data set whose label is the row
 * index parity. Checks that updating the model on the same data, and updating it on an empty
 * data set, both keep the prediction for a fixed probe vector within 0.01 of the original model.
 */
@Test
public void testUpdate() {
    final int sampleSize = 1000;

    // Each row's features are the row index scaled down by successive factors of 10.
    Map<Integer, LabeledVector<Double>> trainData = new HashMap<>();
    for (int row = 0; row < sampleSize; row++) {
        double first = row;
        double second = first / 10.0;
        double third = second / 10.0;
        double fourth = third / 10.0;

        trainData.put(row, VectorUtils.of(first, second, third, fourth).labeled((double) row % 2));
    }

    ArrayList<FeatureMeta> featuresMeta = new ArrayList<>();
    for (int col = 0; col < 4; col++)
        featuresMeta.add(new FeatureMeta("", col, false));

    DatasetTrainer<RandomForestModel, Double> trainer = new RandomForestClassifierTrainer(featuresMeta)
        .withAmountOfTrees(100)
        .withFeaturesCountSelectionStrgy(cnt -> 2)
        .withEnvironmentBuilder(TestUtils.testEnvBuilder());

    RandomForestModel initial = trainer.fit(trainData, parts, new LabeledDummyVectorizer<>());
    RandomForestModel refitSame = trainer.update(initial, trainData, parts, new LabeledDummyVectorizer<>());
    RandomForestModel refitEmpty = trainer.update(
        initial,
        new HashMap<Integer, LabeledVector<Double>>(),
        parts,
        new LabeledDummyVectorizer<>()
    );

    Vector probe = VectorUtils.of(5, 0.5, 0.05, 0.005);

    assertEquals(initial.predict(probe), refitSame.predict(probe), 0.01);
    assertEquals(initial.predict(probe), refitEmpty.predict(probe), 0.01);
}
Also used : TrainerTest(org.apache.ignite.ml.common.TrainerTest) OnMajorityPredictionsAggregator(org.apache.ignite.ml.composition.predictionsaggregator.OnMajorityPredictionsAggregator) TestUtils(org.apache.ignite.ml.TestUtils) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) Assert.assertTrue(org.junit.Assert.assertTrue) HashMap(java.util.HashMap) Test(org.junit.Test) DatasetTrainer(org.apache.ignite.ml.trainers.DatasetTrainer) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) ArrayList(java.util.ArrayList) FeatureMeta(org.apache.ignite.ml.dataset.feature.FeatureMeta) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) Map(java.util.Map) Assert.assertEquals(org.junit.Assert.assertEquals) HashMap(java.util.HashMap) ArrayList(java.util.ArrayList) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) FeatureMeta(org.apache.ignite.ml.dataset.feature.FeatureMeta) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 4 with LabeledDummyVectorizer

use of org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer in project ignite by apache.

From the class DecisionTreeClassificationTrainerExample, method main.

/**
 * Runs the decision tree classification example. Fills a cache with 1000 generated labeled
 * points, trains a {@code DecisionTreeClassificationTrainer(4, 0)} on it, then scores the model
 * on further generated points and prints the accuracy as a percentage.
 *
 * @param args Command line arguments; none are used.
 */
public static void main(String... args) {
    System.out.println(">>> Decision tree classification trainer example started.");

    // Start ignite grid; the node is closed when the try block exits.
    try (Ignite ignite = Ignition.start("examples/config/example-ignite.xml")) {
        System.out.println(">>> Ignite grid started.");

        // Create cache with training data.
        CacheConfiguration<Integer, LabeledVector<Double>> trainingSetCfg = new CacheConfiguration<>();
        trainingSetCfg.setName("TRAINING_SET");
        trainingSetCfg.setAffinity(new RendezvousAffinityFunction(false, 10));

        IgniteCache<Integer, LabeledVector<Double>> trainingSet = null;

        try {
            trainingSet = ignite.createCache(trainingSetCfg);

            Random rnd = new Random(0);

            // Fill training data.
            for (int i = 0; i < 1000; i++)
                trainingSet.put(i, generatePoint(rnd));

            // Create classification trainer.
            DecisionTreeClassificationTrainer trainer = new DecisionTreeClassificationTrainer(4, 0);

            // Train decision tree model.
            LabeledDummyVectorizer<Integer, Double> vectorizer = new LabeledDummyVectorizer<>();
            DecisionTreeModel mdl = trainer.fit(ignite, trainingSet, vectorizer);

            System.out.println(">>> Decision tree classification model: " + mdl);

            // Calculate score on further points drawn from the same generator.
            // The percentage below is derived from testSize, so changing the count keeps it correct.
            final int testSize = 1000;
            int correctPredictions = 0;

            for (int i = 0; i < testSize; i++) {
                LabeledVector<Double> pnt = generatePoint(rnd);

                double prediction = mdl.predict(pnt.features());
                double lbl = pnt.label();

                // Print a sparse sample of predictions to keep the output short.
                if (i % 50 == 1)
                    System.out.printf(">>> test #: %d\t\t predicted: %.4f\t\tlabel: %.4f\n", i, prediction, lbl);

                if (Precision.equals(prediction, lbl, Precision.EPSILON))
                    correctPredictions++;
            }

            System.out.println(">>> Accuracy: " + correctPredictions * 100.0 / testSize + "%");

            System.out.println(">>> Decision tree classification trainer example completed.");
        }
        finally {
            // createCache may throw before trainingSet is assigned; guard so the original
            // exception is not replaced by a NullPointerException from this finally block.
            if (trainingSet != null)
                trainingSet.destroy();
        }
    }
    finally {
        System.out.flush();
    }
}
Also used : DecisionTreeModel(org.apache.ignite.ml.tree.DecisionTreeModel) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) DecisionTreeClassificationTrainer(org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer) Random(java.util.Random) Ignite(org.apache.ignite.Ignite) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration)

Example 5 with LabeledDummyVectorizer

use of org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer in project ignite by apache.

From the class DataStreamGeneratorFillCacheTest, method testCacheFilling.

/**
 * Starts a local node and fills a cache with 5000 vectors from a
 * {@code GaussRandomProducer(0).vectorize(1)} data stream keyed by UUID. Then, through a
 * cache-based dataset, it checks that every row is present and that the summed feature values
 * average to approximately 0 (tolerance 1e-2). The cache is destroyed before the node stops.
 */
@Test
public void testCacheFilling() {
    IgniteConfiguration igniteCfg = new IgniteConfiguration()
        .setDiscoverySpi(new TcpDiscoverySpi()
            .setIpFinder(new TcpDiscoveryVmIpFinder()
                .setAddresses(Arrays.asList("127.0.0.1:47500..47509"))));

    final String cacheName = "TEST_CACHE";

    CacheConfiguration<UUID, LabeledVector<Double>> cacheCfg =
        new CacheConfiguration<UUID, LabeledVector<Double>>(cacheName)
            .setAffinity(new RendezvousAffinityFunction(false, 10));

    final int expectedRows = 5000;

    try (Ignite node = Ignition.start(igniteCfg)) {
        IgniteCache<UUID, LabeledVector<Double>> cache = node.getOrCreateCache(cacheCfg);

        DataStreamGenerator stream = new GaussRandomProducer(0).vectorize(1).asDataStream();
        stream.fillCacheWithVecUUIDAsKey(expectedRows, cache);

        LabeledDummyVectorizer<UUID, Double> vectorizer = new LabeledDummyVectorizer<>();
        CacheBasedDatasetBuilder<UUID, LabeledVector<Double>> builder = new CacheBasedDatasetBuilder<>(node, cache);

        // Per-partition reducer: sum of all feature values plus the partition's row count.
        IgniteFunction<SimpleDatasetData, StatPair> sumAndCount =
            part -> new StatPair(DoubleStream.of(part.getFeatures()).sum(), part.getRows());

        LearningEnvironment trainerEnv = LearningEnvironmentBuilder.defaultBuilder().buildForTrainer();
        trainerEnv.deployingContext().initByClientObject(sumAndCount);

        try (CacheBasedDataset<UUID, LabeledVector<Double>, EmptyContext, SimpleDatasetData> dataset =
                 builder.build(
                     LearningEnvironmentBuilder.defaultBuilder(),
                     new EmptyContextBuilder<>(),
                     new SimpleDatasetDataBuilder<>(vectorizer),
                     trainerEnv)) {
            StatPair totals = dataset.compute(sumAndCount, StatPair::sum);

            assertEquals(expectedRows, totals.cntOfRows);
            assertEquals(0.0, totals.elementsSum / totals.cntOfRows, 1e-2);
        }

        node.destroyCache(cacheName);
    }
}
Also used : Arrays(java.util.Arrays) IgniteFunction(org.apache.ignite.ml.math.functions.IgniteFunction) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) EmptyContextBuilder(org.apache.ignite.ml.dataset.primitive.builder.context.EmptyContextBuilder) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) LearningEnvironmentBuilder(org.apache.ignite.ml.environment.LearningEnvironmentBuilder) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) SimpleDatasetDataBuilder(org.apache.ignite.ml.dataset.primitive.builder.data.SimpleDatasetDataBuilder) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) Test(org.junit.Test) UUID(java.util.UUID) Ignite(org.apache.ignite.Ignite) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) IgniteCache(org.apache.ignite.IgniteCache) LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) DoubleStream(java.util.stream.DoubleStream) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) Ignition(org.apache.ignite.Ignition) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) CacheBasedDataset(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDataset) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi) EmptyContext(org.apache.ignite.ml.dataset.primitive.context.EmptyContext) TcpDiscoveryVmIpFinder(org.apache.ignite.spi.discovery.tcp.ipfinder.vm.TcpDiscoveryVmIpFinder) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) 
LabeledDummyVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) Ignite(org.apache.ignite.Ignite) UUID(java.util.UUID) TcpDiscoverySpi(org.apache.ignite.spi.discovery.tcp.TcpDiscoverySpi) GaussRandomProducer(org.apache.ignite.ml.util.generators.primitives.scalar.GaussRandomProducer) CacheBasedDatasetBuilder(org.apache.ignite.ml.dataset.impl.cache.CacheBasedDatasetBuilder) LearningEnvironment(org.apache.ignite.ml.environment.LearningEnvironment) IgniteConfiguration(org.apache.ignite.configuration.IgniteConfiguration) SimpleDatasetData(org.apache.ignite.ml.dataset.primitive.data.SimpleDatasetData) GridCommonAbstractTest(org.apache.ignite.testframework.junits.common.GridCommonAbstractTest) Test(org.junit.Test)

Aggregations

LabeledDummyVectorizer (org.apache.ignite.ml.dataset.feature.extractor.impl.LabeledDummyVectorizer)5 LabeledVector (org.apache.ignite.ml.structures.LabeledVector)5 Ignite (org.apache.ignite.Ignite)4 RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction)4 CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration)4 Random (java.util.Random)3 DecisionTreeClassificationTrainer (org.apache.ignite.ml.tree.DecisionTreeClassificationTrainer)3 DecisionTreeModel (org.apache.ignite.ml.tree.DecisionTreeModel)3 Test (org.junit.Test)2 Path (java.nio.file.Path)1 ArrayList (java.util.ArrayList)1 Arrays (java.util.Arrays)1 HashMap (java.util.HashMap)1 Map (java.util.Map)1 UUID (java.util.UUID)1 DoubleStream (java.util.stream.DoubleStream)1 IgniteCache (org.apache.ignite.IgniteCache)1 Ignition (org.apache.ignite.Ignition)1 IgniteConfiguration (org.apache.ignite.configuration.IgniteConfiguration)1 TestUtils (org.apache.ignite.ml.TestUtils)1