Search in sources:

Example 1 with ConvergenceCheckerStubFactory

Usage of org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory in the Apache Ignite project.

Class GDBTrainerTest, method testClassifier.

/**
 * Trains a {@link GDBBinaryClassifierOnTreesTrainer} on a one-dimensional step-like labelling and checks the
 * fitted model and the effect of the convergence checker on ensemble size.
 * <p>
 * The sample is {@code x = 0..99}. The label is {@code -1.0} when {@code (int)(x / 10.0)} is even and
 * {@code 1.0} otherwise. With {@link MeanAbsValueConvergenceCheckerFactory}, the test asserts:
 * <ul>
 *   <li>there are no mispredictions on the training points;</li>
 *   <li>the result is a {@link ModelsComposition} of {@link DecisionTreeModel}s;</li>
 *   <li>the composition has fewer models than the iteration limit;</li>
 *   <li>the composition uses a {@link WeightedPredictionsAggregator}.</li>
 * </ul>
 * It then swaps in {@link ConvergenceCheckerStubFactory} and asserts that exactly the iteration-limit
 * number of models is produced.
 *
 * @param fitter Fits the given trainer on a map from sample index to a {@code {x, label}} pair and returns
 *     the trained model. How the map is converted into a dataset is up to the caller.
 */
private void testClassifier(BiFunction<GDBTrainer, Map<Integer, double[]>, IgniteModel<Vector, Double>> fitter) {
    // Iteration limit handed to the trainer; also the expected model count once convergence checks are stubbed.
    final int iterations = 500;
    final int size = 100;

    double[] features = new double[size];
    double[] labels = new double[size];
    Map<Integer, double[]> sample = new HashMap<>();

    for (int idx = 0; idx < size; idx++) {
        features[idx] = idx;
        // Labels flip between -1.0 and 1.0 every 10 consecutive x values.
        int bucket = (int)(features[idx] / 10.0);
        labels[idx] = bucket % 2 == 0 ? -1.0 : 1.0;
        sample.put(idx, new double[] {features[idx], labels[idx]});
    }

    GDBTrainer trainer = new GDBBinaryClassifierOnTreesTrainer(0.3, iterations, 3, 0.0)
        .withUsingIdx(true)
        .withCheckConvergenceStgyFactory(new MeanAbsValueConvergenceCheckerFactory(0.3));

    IgniteModel<Vector, Double> mdl = fitter.apply(trainer, sample);

    // Every training point must be classified exactly right.
    int mispredicted = 0;
    for (int idx = 0; idx < size; idx++) {
        double predicted = mdl.predict(VectorUtils.of(features[idx]));
        if (predicted != labels[idx]) {
            mispredicted++;
        }
    }
    assertEquals(0, mispredicted);

    assertTrue(mdl instanceof ModelsComposition);
    ModelsComposition composition = (ModelsComposition)mdl;
    composition.getModels().forEach(m -> assertTrue(m instanceof DecisionTreeModel));
    // The mean-absolute-value checker is expected to stop boosting before the iteration limit.
    assertTrue(composition.getModels().size() < iterations);
    assertTrue(composition.getPredictionsAggregator() instanceof WeightedPredictionsAggregator);

    // With the stub checker, boosting is expected to run for every configured iteration.
    GDBTrainer stubbedTrainer = trainer.withCheckConvergenceStgyFactory(new ConvergenceCheckerStubFactory());
    ModelsComposition fullComposition = (ModelsComposition)fitter.apply(stubbedTrainer, sample);
    assertEquals(iterations, fullComposition.getModels().size());
}
Also used: HashMap (java.util.HashMap), DecisionTreeModel (org.apache.ignite.ml.tree.DecisionTreeModel), WeightedPredictionsAggregator (org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator), ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition), GDBBinaryClassifierOnTreesTrainer (org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer), MeanAbsValueConvergenceCheckerFactory (org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory), ConvergenceCheckerStubFactory (org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory), Vector (org.apache.ignite.ml.math.primitives.vector.Vector)

Aggregations

HashMap (java.util.HashMap): 1; ModelsComposition (org.apache.ignite.ml.composition.ModelsComposition): 1; MeanAbsValueConvergenceCheckerFactory (org.apache.ignite.ml.composition.boosting.convergence.mean.MeanAbsValueConvergenceCheckerFactory): 1; ConvergenceCheckerStubFactory (org.apache.ignite.ml.composition.boosting.convergence.simple.ConvergenceCheckerStubFactory): 1; WeightedPredictionsAggregator (org.apache.ignite.ml.composition.predictionsaggregator.WeightedPredictionsAggregator): 1; Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 1; DecisionTreeModel (org.apache.ignite.ml.tree.DecisionTreeModel): 1; GDBBinaryClassifierOnTreesTrainer (org.apache.ignite.ml.tree.boosting.GDBBinaryClassifierOnTreesTrainer): 1