Use of org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector in the project ignite by Apache.
Shown below: the method testAddVector of the class GiniFeatureHistogramTest.
/**
 * Adds every vector of {@code dataset} to four {@code GiniHistogram} instances and verifies
 * the resulting bucket ids and the per-label counters.
 *
 * <p>Two histograms are built over {@code feature1Meta} and two over {@code feature2Meta};
 * within each pair the first constructor argument is 0 and 1 respectively
 * (presumably the sample id -- confirm against {@code GiniHistogram}).
 */
@Test
public void testAddVector() {
    // Labels 1.0, 2.0 and 3.0 are mapped to the indices 0, 1 and 2.
    Map<Double, Integer> labelToIdx = new HashMap<>();
    labelToIdx.put(1.0, 0);
    labelToIdx.put(2.0, 1);
    labelToIdx.put(3.0, 2);

    GiniHistogram catHist0 = new GiniHistogram(0, labelToIdx, feature1Meta);
    GiniHistogram catHist1 = new GiniHistogram(1, labelToIdx, feature1Meta);
    GiniHistogram contHist0 = new GiniHistogram(0, labelToIdx, feature2Meta);
    GiniHistogram contHist1 = new GiniHistogram(1, labelToIdx, feature2Meta);

    // Every vector is offered to all four histograms.
    for (BootstrappedVector sample : dataset) {
        catHist0.addElement(sample);
        catHist1.addElement(sample);
        contHist0.addElement(sample);
        contHist1.addElement(sample);
    }

    // Overall bucket ids of each histogram.
    checkBucketIds(catHist0.buckets(), new Integer[] { 0, 1 });
    checkBucketIds(catHist1.buckets(), new Integer[] { 0, 1 });
    checkBucketIds(contHist0.buckets(), new Integer[] { 1, 4, 6, 7, 8 });
    checkBucketIds(contHist1.buckets(), new Integer[] { 1, 4, 6, 7, 8 });

    // ---- feature1Meta histogram, first argument 0 ----

    // Label 1.0: buckets 0 and 1.
    checkCounters(catHist0.getHistForLabel(1.0), new double[] { 2, 1 });
    checkBucketIds(catHist0.getHistForLabel(1.0).buckets(), new Integer[] { 0, 1 });

    // Label 2.0: bucket 1 only.
    checkCounters(catHist0.getHistForLabel(2.0), new double[] { 3 });
    checkBucketIds(catHist0.getHistForLabel(2.0).buckets(), new Integer[] { 1 });

    // Label 3.0: bucket 0 only.
    checkCounters(catHist0.getHistForLabel(3.0), new double[] { 2 });
    checkBucketIds(catHist0.getHistForLabel(3.0).buckets(), new Integer[] { 0 });

    // ---- feature1Meta histogram, first argument 1 ----

    // Label 1.0: buckets 0 and 1.
    checkCounters(catHist1.getHistForLabel(1.0), new double[] { 1, 2 });
    checkBucketIds(catHist1.getHistForLabel(1.0).buckets(), new Integer[] { 0, 1 });

    // Label 2.0: bucket 1 only.
    checkCounters(catHist1.getHistForLabel(2.0), new double[] { 3 });
    checkBucketIds(catHist1.getHistForLabel(2.0).buckets(), new Integer[] { 1 });

    // Label 3.0: bucket 0 is present, with an expected counter of 0.
    checkCounters(catHist1.getHistForLabel(3.0), new double[] { 0 });
    checkBucketIds(catHist1.getHistForLabel(3.0).buckets(), new Integer[] { 0 });

    // ---- feature2Meta histogram, first argument 0 ----

    // Label 1.0: buckets 4 and 6.
    checkCounters(contHist0.getHistForLabel(1.0), new double[] { 1, 2 });
    checkBucketIds(contHist0.getHistForLabel(1.0).buckets(), new Integer[] { 4, 6 });

    // Label 2.0: buckets 1 and 7.
    checkCounters(contHist0.getHistForLabel(2.0), new double[] { 1, 2 });
    checkBucketIds(contHist0.getHistForLabel(2.0).buckets(), new Integer[] { 1, 7 });

    // Label 3.0: bucket 8 only.
    checkCounters(contHist0.getHistForLabel(3.0), new double[] { 2 });
    checkBucketIds(contHist0.getHistForLabel(3.0).buckets(), new Integer[] { 8 });

    // ---- feature2Meta histogram, first argument 1 ----

    // Label 1.0: buckets 4 and 6.
    checkCounters(contHist1.getHistForLabel(1.0), new double[] { 2, 1 });
    checkBucketIds(contHist1.getHistForLabel(1.0).buckets(), new Integer[] { 4, 6 });

    // Label 2.0: buckets 1 and 7.
    checkCounters(contHist1.getHistForLabel(2.0), new double[] { 2, 1 });
    checkBucketIds(contHist1.getHistForLabel(2.0).buckets(), new Integer[] { 1, 7 });

    // Label 3.0: bucket 8 is present, with an expected counter of 0.
    checkCounters(contHist1.getHistForLabel(3.0), new double[] { 0 });
    checkBucketIds(contHist1.getHistForLabel(3.0).buckets(), new Integer[] { 8 });
}
Aggregations