
Example 6 with BootstrappedVector

Use of org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector in the Apache Ignite project.

In class GiniFeatureHistogramTest, method testAddVector.
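The assertions in the test that follows check, for each bootstrap sample and each label, which buckets a histogram contains and how much weight each bucket holds. A BootstrappedVector roughly pairs a feature vector and a label with one repetition counter per bootstrap sample. A simple mental model, which is an assumption and not GiniHistogram's actual code, is that adding a vector to the histogram for sample s adds counters[s] to the bucket of that vector's label. The sketch below uses hypothetical names (BootstrapHistogramSketch, Row, histogram) and assumes the bucket ID is simply the feature value cast to int, which only makes sense for a categorical feature:

```java
import java.util.Map;
import java.util.TreeMap;

/** Hypothetical model of per-sample, per-label histograms; not Ignite's GiniHistogram API. */
public class BootstrapHistogramSketch {
    /** Hypothetical stand-in for a bootstrapped row: features, label, one repetition counter per sample. */
    public static final class Row {
        final double[] features;
        final double label;
        final int[] counters;

        public Row(double[] features, double label, int[] counters) {
            this.features = features;
            this.label = label;
            this.counters = counters;
        }
    }

    /** Builds label -> (bucket ID -> weighted count) for one feature and one bootstrap sample. */
    public static Map<Double, Map<Integer, Double>> histogram(Row[] rows, int sampleId, int featureId) {
        Map<Double, Map<Integer, Double>> hist = new TreeMap<>();
        for (Row row : rows) {
            // Assumption: bucket ID = feature value cast to int (categorical feature only).
            int bucketId = (int) row.features[featureId];
            // Add the row's repetition count for this sample; a count of 0 still creates the bucket.
            hist.computeIfAbsent(row.label, k -> new TreeMap<>())
                .merge(bucketId, (double) row.counters[sampleId], Double::sum);
        }
        return hist;
    }

    public static void main(String[] args) {
        Row[] rows = {
            new Row(new double[] {0}, 1.0, new int[] {1, 0}),
            new Row(new double[] {1}, 1.0, new int[] {1, 2}),
            new Row(new double[] {0}, 3.0, new int[] {2, 0}),
        };
        System.out.println(histogram(rows, 0, 0)); // {1.0={0=1.0, 1=1.0}, 3.0={0=2.0}}
        System.out.println(histogram(rows, 1, 0)); // {1.0={0=0.0, 1=2.0}, 3.0={0=0.0}}
    }
}
```

Because a counter of 0 still creates the bucket in this model, it is consistent with the test expecting a counter of 0 in bucket 0 for label 3.0 in sample 1.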

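/**
 * Checks that adding bootstrapped vectors to {@link GiniHistogram}s for two samples of a
 * categorical and a continuous feature yields the expected bucket IDs and per-label counters.
 * The {@code dataset}, {@code feature1Meta} and {@code feature2Meta} fixtures and the
 * {@code checkBucketIds}/{@code checkCounters} helpers are defined elsewhere in the test class hierarchy.
 */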
@Test
public void testAddVector() {
    // Map each label value to a class index.
    Map<Double, Integer> lblMapping = new HashMap<>();
    lblMapping.put(1.0, 0);
    lblMapping.put(2.0, 1);
    lblMapping.put(3.0, 2);
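    // Histograms for bootstrap samples 0 and 1 of a categorical (feature1Meta) and a continuous (feature2Meta) feature.
    GiniHistogram catFeatureSmpl1 = new GiniHistogram(0, lblMapping, feature1Meta);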
    GiniHistogram catFeatureSmpl2 = new GiniHistogram(1, lblMapping, feature1Meta);
    GiniHistogram contFeatureSmpl1 = new GiniHistogram(0, lblMapping, feature2Meta);
    GiniHistogram contFeatureSmpl2 = new GiniHistogram(1, lblMapping, feature2Meta);
    // Add every bootstrapped vector of the fixture dataset to all four histograms.
    for (BootstrappedVector vec : dataset) {
        catFeatureSmpl1.addElement(vec);
        catFeatureSmpl2.addElement(vec);
        contFeatureSmpl1.addElement(vec);
        contFeatureSmpl2.addElement(vec);
    }
    // Overall bucket IDs of each histogram, across all labels.
    checkBucketIds(catFeatureSmpl1.buckets(), new Integer[] { 0, 1 });
    checkBucketIds(catFeatureSmpl2.buckets(), new Integer[] { 0, 1 });
    checkBucketIds(contFeatureSmpl1.buckets(), new Integer[] { 1, 4, 6, 7, 8 });
    checkBucketIds(contFeatureSmpl2.buckets(), new Integer[] { 1, 4, 6, 7, 8 });
    // categorical feature, sample 0
    // label 1.0: feature values 0 and 1
    checkCounters(catFeatureSmpl1.getHistForLabel(1.0), new double[] { 2, 1 });
    checkBucketIds(catFeatureSmpl1.getHistForLabel(1.0).buckets(), new Integer[] { 0, 1 });
    // label 2.0: feature value 1
    checkCounters(catFeatureSmpl1.getHistForLabel(2.0), new double[] { 3 });
    checkBucketIds(catFeatureSmpl1.getHistForLabel(2.0).buckets(), new Integer[] { 1 });
    // label 3.0: feature value 0
    checkCounters(catFeatureSmpl1.getHistForLabel(3.0), new double[] { 2 });
    checkBucketIds(catFeatureSmpl1.getHistForLabel(3.0).buckets(), new Integer[] { 0 });
    // categorical feature, sample 1
    // label 1.0: feature values 0 and 1
    checkCounters(catFeatureSmpl2.getHistForLabel(1.0), new double[] { 1, 2 });
    checkBucketIds(catFeatureSmpl2.getHistForLabel(1.0).buckets(), new Integer[] { 0, 1 });
    // label 2.0: feature value 1
    checkCounters(catFeatureSmpl2.getHistForLabel(2.0), new double[] { 3 });
    checkBucketIds(catFeatureSmpl2.getHistForLabel(2.0).buckets(), new Integer[] { 1 });
    // label 3.0: feature value 0 (the bucket exists but its counter is 0 in this sample)
    checkCounters(catFeatureSmpl2.getHistForLabel(3.0), new double[] { 0 });
    checkBucketIds(catFeatureSmpl2.getHistForLabel(3.0).buckets(), new Integer[] { 0 });
    // continuous feature, sample 0
    // label 1.0: buckets 4 and 6
    checkCounters(contFeatureSmpl1.getHistForLabel(1.0), new double[] { 1, 2 });
    checkBucketIds(contFeatureSmpl1.getHistForLabel(1.0).buckets(), new Integer[] { 4, 6 });
    // label 2.0: buckets 1 and 7
    checkCounters(contFeatureSmpl1.getHistForLabel(2.0), new double[] { 1, 2 });
    checkBucketIds(contFeatureSmpl1.getHistForLabel(2.0).buckets(), new Integer[] { 1, 7 });
    // label 3.0: bucket 8
    checkCounters(contFeatureSmpl1.getHistForLabel(3.0), new double[] { 2 });
    checkBucketIds(contFeatureSmpl1.getHistForLabel(3.0).buckets(), new Integer[] { 8 });
    // continuous feature, sample 1
    // label 1.0: buckets 4 and 6
    checkCounters(contFeatureSmpl2.getHistForLabel(1.0), new double[] { 2, 1 });
    checkBucketIds(contFeatureSmpl2.getHistForLabel(1.0).buckets(), new Integer[] { 4, 6 });
    // label 2.0: buckets 1 and 7
    checkCounters(contFeatureSmpl2.getHistForLabel(2.0), new double[] { 2, 1 });
    checkBucketIds(contFeatureSmpl2.getHistForLabel(2.0).buckets(), new Integer[] { 1, 7 });
    // label 3.0: bucket 8 (the bucket exists but its counter is 0 in this sample)
    checkCounters(contFeatureSmpl2.getHistForLabel(3.0), new double[] { 0 });
    checkBucketIds(contFeatureSmpl2.getHistForLabel(3.0).buckets(), new Integer[] { 8 });
}
Also used: HashMap (java.util.HashMap), BootstrappedVector (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector), Test (org.junit.Test)
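GiniHistogram is named after the Gini impurity criterion, and the per-label bucket counters checked in the test above are the inputs that criterion works from. As a worked example using the textbook formulas (Gini = 1 - sum(p_i^2), and a binary split scored by the size-weighted Gini of its two sides; this is not necessarily how GiniHistogram scores splits internally), the sketch below recomputes impurities from the sample-0 counters the test expects for the categorical feature: label 1.0 has 2 in bucket 0 and 1 in bucket 1, label 2.0 has 3 in bucket 1, and label 3.0 has 2 in bucket 0. The class name GiniSketch and its methods are illustrative only.

```java
/** Textbook Gini impurity and size-weighted split impurity over per-label counts (illustrative only). */
public class GiniSketch {
    /** Gini impurity 1 - sum(p_i^2) of one node, given the weighted count of each label. */
    public static double gini(double[] counts) {
        double total = 0;
        for (double c : counts) total += c;
        if (total == 0) return 0; // treat an empty node as pure
        double sumSq = 0;
        for (double c : counts) {
            double p = c / total;
            sumSq += p * p;
        }
        return 1 - sumSq;
    }

    /** Impurity of a binary split: each side's Gini weighted by that side's share of the total count. */
    public static double splitGini(double[] left, double[] right) {
        double l = 0, r = 0;
        for (double c : left) l += c;
        for (double c : right) r += c;
        if (l + r == 0) return 0;
        return (l * gini(left) + r * gini(right)) / (l + r);
    }

    public static void main(String[] args) {
        // Sample-0 label totals for the categorical feature: 1.0 -> 2 + 1, 2.0 -> 3, 3.0 -> 2.
        System.out.println(gini(new double[] {3, 3, 2})); // 0.65625
        // Split bucket 0 from bucket 1: left = {2, 0, 2}, right = {1, 3, 0} (labels 1.0, 2.0, 3.0).
        System.out.println(splitGini(new double[] {2, 0, 2}, new double[] {1, 3, 0})); // 0.4375
    }
}
```

Separating bucket 0 from bucket 1 lowers the impurity from 0.65625 to 0.4375, which is the kind of comparison a decision-tree learner makes when choosing among candidate splits.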

Aggregations

BootstrappedVector (org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector): 6
Test (org.junit.Test): 6
HashMap (java.util.HashMap): 3
ArrayList (java.util.ArrayList): 2
BucketMeta (org.apache.ignite.ml.dataset.feature.BucketMeta): 2
FeatureMeta (org.apache.ignite.ml.dataset.feature.FeatureMeta): 2
NodeSplit (org.apache.ignite.ml.tree.randomforest.data.NodeSplit): 1