Search in sources:

Example 1 with DecisionTreeData

Use of org.apache.ignite.ml.tree.data.DecisionTreeData in the Apache Ignite project.

From class GDBOnTreesLearningStrategy, method update.

/**
 * {@inheritDoc}
 *
 * <p>Gradient-boosting update over decision trees. Starting from the models returned by
 * {@code initLearningState(mdlToUpdate)}, runs at most {@code cntOfIterations} iterations. Each iteration:
 * <ol>
 *     <li>builds a {@link ModelsComposition} from the current models, aggregated by a
 *     {@link WeightedPredictionsAggregator} over the first {@code models.size()} entries of
 *     {@code compositionWeights} and {@code meanLbVal};</li>
 *     <li>stops early if the convergence checker reports that composition as converged;</li>
 *     <li>otherwise overwrites every partition's labels with
 *     {@code -loss.gradient(sampleSize, originalLabel, prediction)}, where the original label is first mapped
 *     through {@code externalLbToInternalMapping}, and fits one more decision tree on the rewritten dataset.</li>
 * </ol>
 * On exit {@code compositionWeights} is resized to the final number of models.
 *
 * @param mdlToUpdate Model handed to {@code initLearningState} to obtain the initial list of base models.
 * @param datasetBuilder Builder of the training dataset.
 * @param vectorizer Preprocessor used both by the convergence checker and to build the tree dataset.
 * @return The initial base models followed by every tree fitted in this call.
 * @throws RuntimeException Wrapping any exception thrown while building, computing over or closing the dataset,
 *     or while fitting a tree.
 */
@Override
public <K, V> List<IgniteModel<Vector, Double>> update(GDBModel mdlToUpdate, DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> vectorizer) {
    LearningEnvironment environment = envBuilder.buildForTrainer();
    environment.initDeployingContext(vectorizer);

    DatasetTrainer<? extends IgniteModel<Vector, Double>, Double> trainer = baseMdlTrainerBuilder.get();
    // The dataset below is built as DecisionTreeData and passed directly to DecisionTreeTrainer.fit(dataset).
    assert trainer instanceof DecisionTreeTrainer;
    DecisionTreeTrainer decisionTreeTrainer = (DecisionTreeTrainer) trainer;

    List<IgniteModel<Vector, Double>> models = initLearningState(mdlToUpdate);
    ConvergenceChecker<K, V> convCheck = checkConvergenceStgyFactory.create(sampleSize, externalLbToInternalMapping, loss, datasetBuilder, vectorizer);

    try (Dataset<EmptyContext, DecisionTreeData> dataset = datasetBuilder.build(envBuilder, new EmptyContextBuilder<>(), new DecisionTreeDataBuilder<>(vectorizer, useIdx), environment)) {
        for (int i = 0; i < cntOfIterations; i++) {
            // NOTE(review): Arrays.copyOf pads with 0.0 if compositionWeights is shorter than models.size();
            // presumably it is pre-sized elsewhere — confirm.
            double[] weights = Arrays.copyOf(compositionWeights, models.size());
            WeightedPredictionsAggregator aggregator = new WeightedPredictionsAggregator(weights, meanLbVal);
            ModelsComposition currComposition = new ModelsComposition(models, aggregator);

            if (convCheck.isConverged(dataset, currComposition))
                break;

            dataset.compute(part -> {
                // The labels array is overwritten with gradients below, so keep a one-time copy of the
                // original labels per partition to recompute gradients on later iterations.
                if (part.getCopiedOriginalLabels() == null)
                    part.setCopiedOriginalLabels(Arrays.copyOf(part.getLabels(), part.getLabels().length));

                // Loop-invariant lookups hoisted out of the per-row loop; labels is written in place.
                double[] labels = part.getLabels();
                double[][] features = part.getFeatures();
                double[] originalLabels = part.getCopiedOriginalLabels();

                for (int j = 0; j < labels.length; j++) {
                    double mdlAnswer = currComposition.predict(VectorUtils.of(features[j]));
                    double originalLbVal = externalLbToInternalMapping.apply(originalLabels[j]);
                    labels[j] = -loss.gradient(sampleSize, originalLbVal, mdlAnswer);
                }
            });

            // nanoTime is the monotonic source intended for elapsed-time measurement; currentTimeMillis
            // follows the wall clock and can jump when the system clock is adjusted.
            long startNanos = System.nanoTime();
            models.add(decisionTreeTrainer.fit(dataset));
            double learningTime = (System.nanoTime() - startNanos) / 1e9;
            trainerEnvironment.logger(getClass()).log(MLLogger.VerboseLevel.LOW, "One model training time was %.2fs", learningTime);
        }
    } catch (Exception e) {
        // Presumably needed because closing the Dataset may throw a checked exception.
        throw new RuntimeException(e);
    }

    compositionWeights = Arrays.copyOf(compositionWeights, models.size());
    return models;
}
Also used: EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext), WeightedPredictionsAggregator (org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator), ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition), DecisionTreeTrainer (org.apache.ignite.ml.tree.DecisionTreeTrainer), LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment), DecisionTreeData (org.apache.ignite.ml.tree.data.DecisionTreeData), IgniteModel (org.apache.ignite.ml.IgniteModel), Vector (org.apache.ignite.ml.math.primitives.vector.Vector)

Example 2 with DecisionTreeData

Use of org.apache.ignite.ml.tree.data.DecisionTreeData in the Apache Ignite project.

From class MSEImpurityMeasureCalculatorTest, method testCalculate.

/**
 * Checks the MSE impurity step functions computed for a two-feature, four-row dataset: each feature must
 * produce the thresholds {-inf, 0, 1, 2, 3}, and the impurity at every threshold must match the expected
 * table below within 1e-3.
 */
@Test
public void testCalculate() {
    double[][] features = new double[][] { { 0, 2 }, { 1, 1 }, { 2, 0 }, { 3, 3 } };
    double[] targets = new double[] { 1, 2, 2, 1 };

    MSEImpurityMeasureCalculator calc = new MSEImpurityMeasureCalculator(useIdx);
    StepFunction<MSEImpurityMeasure>[] stepFunctions = calc.calculate(new DecisionTreeData(features, targets, useIdx), fs -> true, 0);

    double[] expThresholds = { Double.NEGATIVE_INFINITY, 0, 1, 2, 3 };
    // Row c: expected MSE impurity of column c at each threshold in expThresholds.
    double[][] expImpurities = {
        { 1.000, 0.666, 1.000, 0.666, 1.000 },
        { 1.000, 0.666, 0.000, 0.666, 1.000 }
    };

    assertEquals(2, stepFunctions.length);

    for (int col = 0; col < expImpurities.length; col++) {
        assertArrayEquals(expThresholds, stepFunctions[col].getX(), 1e-10);

        for (int pos = 0; pos < expImpurities[col].length; pos++)
            assertEquals(expImpurities[col][pos], stepFunctions[col].getY()[pos].impurity(), 1e-3);
    }
}
Also used: StepFunction (org.apache.ignite.ml.tree.impurity.util.StepFunction), DecisionTreeData (org.apache.ignite.ml.tree.data.DecisionTreeData), Test (org.junit.Test)

Example 3 with DecisionTreeData

Use of org.apache.ignite.ml.tree.data.DecisionTreeData in the Apache Ignite project.

From class GiniImpurityMeasureCalculatorTest, method testCalculate.

/**
 * Checks the Gini impurity step functions computed for a two-feature, four-row dataset with label values
 * 0 and 1: each feature must produce the thresholds {-inf, 0, 1, 2, 3}, and the impurity at every threshold
 * must match the expected table below within 1e-3.
 */
@Test
public void testCalculate() {
    double[][] features = new double[][] { { 0, 1 }, { 1, 0 }, { 2, 2 }, { 3, 3 } };
    double[] targets = new double[] { 0, 1, 1, 1 };

    // Encoder from each label value to the integer code handed to the calculator.
    Map<Double, Integer> lblEncoder = new HashMap<>();
    lblEncoder.put(0.0, 0);
    lblEncoder.put(1.0, 1);

    GiniImpurityMeasureCalculator calc = new GiniImpurityMeasureCalculator(lblEncoder, useIdx);
    StepFunction<GiniImpurityMeasure>[] stepFunctions = calc.calculate(new DecisionTreeData(features, targets, useIdx), fs -> true, 0);

    double[] expThresholds = { Double.NEGATIVE_INFINITY, 0, 1, 2, 3 };
    // Row c: expected Gini impurity of column c at each threshold in expThresholds.
    double[][] expImpurities = {
        { -2.500, -4.000, -3.000, -2.666, -2.500 },
        { -2.500, -2.666, -3.000, -2.666, -2.500 }
    };

    assertEquals(2, stepFunctions.length);

    for (int col = 0; col < expImpurities.length; col++) {
        assertArrayEquals(expThresholds, stepFunctions[col].getX(), 1e-10);

        for (int pos = 0; pos < expImpurities[col].length; pos++)
            assertEquals(expImpurities[col][pos], stepFunctions[col].getY()[pos].impurity(), 1e-3);
    }
}
Also used: HashMap (java.util.HashMap), StepFunction (org.apache.ignite.ml.tree.impurity.util.StepFunction), DecisionTreeData (org.apache.ignite.ml.tree.data.DecisionTreeData), Test (org.junit.Test)

Example 4 with DecisionTreeData

Use of org.apache.ignite.ml.tree.data.DecisionTreeData in the Apache Ignite project.

From class GiniImpurityMeasureCalculatorTest, method testCalculateWithRepeatedData.

/**
 * Checks the Gini impurity step function for a single feature containing a repeated value (2, twice):
 * the five rows must still yield exactly the five thresholds {-inf, 0, 1, 2, 3}, with the expected impurity
 * at each threshold within 1e-3.
 */
@Test
public void testCalculateWithRepeatedData() {
    double[][] features = new double[][] { { 0 }, { 1 }, { 2 }, { 2 }, { 3 } };
    double[] targets = new double[] { 0, 1, 1, 1, 1 };

    // Encoder from each label value to the integer code handed to the calculator.
    Map<Double, Integer> lblEncoder = new HashMap<>();
    lblEncoder.put(0.0, 0);
    lblEncoder.put(1.0, 1);

    GiniImpurityMeasureCalculator calc = new GiniImpurityMeasureCalculator(lblEncoder, useIdx);
    StepFunction<GiniImpurityMeasure>[] stepFunctions = calc.calculate(new DecisionTreeData(features, targets, useIdx), fs -> true, 0);

    assertEquals(1, stepFunctions.length);

    StepFunction<GiniImpurityMeasure> onlyColumn = stepFunctions[0];
    assertArrayEquals(new double[] { Double.NEGATIVE_INFINITY, 0, 1, 2, 3 }, onlyColumn.getX(), 1e-10);

    double[] expImpurities = { -3.400, -5.000, -4.000, -3.500, -3.400 };
    for (int pos = 0; pos < expImpurities.length; pos++)
        assertEquals(expImpurities[pos], onlyColumn.getY()[pos].impurity(), 1e-3);
}
Also used: HashMap (java.util.HashMap), StepFunction (org.apache.ignite.ml.tree.impurity.util.StepFunction), DecisionTreeData (org.apache.ignite.ml.tree.data.DecisionTreeData), Test (org.junit.Test)

Aggregations

DecisionTreeData (org.apache.ignite.ml.tree.data.DecisionTreeData): 4, StepFunction (org.apache.ignite.ml.tree.impurity.util.StepFunction): 3, Test (org.junit.Test): 3, HashMap (java.util.HashMap): 2, IgniteModel (org.apache.ignite.ml.IgniteModel): 1, ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition): 1, WeightedPredictionsAggregator (org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator): 1, EmptyContext (org.apache.ignite.ml.dataset.primitive.context.EmptyContext): 1, LearningEnvironment (org.apache.ignite.ml.environment.LearningEnvironment): 1, Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 1, DecisionTreeTrainer (org.apache.ignite.ml.tree.DecisionTreeTrainer): 1