Usage of org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector in the Apache Ignite project:
class GiniFeatureHistogramTest, method testOfSums.
/**
 * Checks that a Gini histogram built over a whole random dataset equals the sum of the histograms
 * built over a random partitioning of the same rows. This is checked for two bucket metas, and the
 * test also checks that adding an empty histogram on either side leaves the histogram unchanged.
 */
@Test
public void testOfSums() {
    int sampleId = 0;

    BucketMeta bucketMeta1 = new BucketMeta(new FeatureMeta("", 0, false));
    bucketMeta1.setMinVal(0.);
    bucketMeta1.setBucketSize(0.1);
    BucketMeta bucketMeta2 = new BucketMeta(new FeatureMeta("", 1, true));

    GiniHistogram forAllHist1 = new GiniHistogram(sampleId, lblMapping, bucketMeta1);
    GiniHistogram forAllHist2 = new GiniHistogram(sampleId, lblMapping, bucketMeta2);

    List<GiniHistogram> partitions1 = new ArrayList<>();
    List<GiniHistogram> partitions2 = new ArrayList<>();

    // At least one partition is required: with zero partitions, rnd.nextInt(cntOfPartitions) below
    // would throw IllegalArgumentException (assuming rnd is a java.util.Random) for any non-empty
    // dataset, making the test fail randomly.
    int cntOfPartitions = rnd.nextInt(1000) + 1;
    for (int i = 0; i < cntOfPartitions; i++) {
        partitions1.add(new GiniHistogram(sampleId, lblMapping, bucketMeta1));
        partitions2.add(new GiniHistogram(sampleId, lblMapping, bucketMeta2));
    }

    // At least one row, so the sum check is never vacuous.
    int datasetSize = rnd.nextInt(10000) + 1;
    for (int i = 0; i < datasetSize; i++) {
        BootstrappedVector vec = randomVector(true);
        // Rewrite feature 1 in place before the vector is added to any histogram.
        vec.features().set(1, (vec.features().get(1) * 100) % 100);

        forAllHist1.addElement(vec);
        forAllHist2.addElement(vec);

        // Each row goes to exactly one randomly chosen partition.
        int partId = rnd.nextInt(cntOfPartitions);
        partitions1.get(partId).addElement(vec);
        partitions2.get(partId).addElement(vec);
    }

    checkSums(forAllHist1, partitions1);
    checkSums(forAllHist2, partitions2);

    // An empty histogram must act as the neutral element of plus(), on both sides.
    GiniHistogram emptyHist1 = new GiniHistogram(sampleId, lblMapping, bucketMeta1);
    GiniHistogram emptyHist2 = new GiniHistogram(sampleId, lblMapping, bucketMeta2);
    assertTrue(forAllHist1.isEqualTo(forAllHist1.plus(emptyHist1)));
    assertTrue(forAllHist2.isEqualTo(forAllHist2.plus(emptyHist2)));
    assertTrue(forAllHist1.isEqualTo(emptyHist1.plus(forAllHist1)));
    assertTrue(forAllHist2.isEqualTo(emptyHist2.plus(forAllHist2)));
}
Usage of org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector in the Apache Ignite project:
class MSEHistogramTest, method testOfSums.
/**
 * Verifies that an MSE histogram filled with every row of a random dataset equals the sum of the
 * histograms filled with a random partitioning of those rows. This is checked for two bucket metas,
 * and the test also checks that combining with an empty histogram (either operand order) is a no-op.
 */
@Test
public void testOfSums() {
    final int smplId = 0;

    BucketMeta meta0 = new BucketMeta(new FeatureMeta("", 0, false));
    meta0.setMinVal(0.);
    meta0.setBucketSize(0.1);
    BucketMeta meta1 = new BucketMeta(new FeatureMeta("", 1, true));

    MSEHistogram wholeHist0 = new MSEHistogram(smplId, meta0);
    MSEHistogram wholeHist1 = new MSEHistogram(smplId, meta1);

    // Between 1 and 100 partitions, one histogram per partition and per bucket meta.
    int partsCnt = 1 + rnd.nextInt(100);
    List<MSEHistogram> parts0 = new ArrayList<>(partsCnt);
    List<MSEHistogram> parts1 = new ArrayList<>(partsCnt);
    for (int p = 0; p < partsCnt; p++) {
        parts0.add(new MSEHistogram(smplId, meta0));
        parts1.add(new MSEHistogram(smplId, meta1));
    }

    // Between 1 and 1000 rows; each row is added to both whole histograms and to one random partition.
    int rowsCnt = 1 + rnd.nextInt(1000);
    for (int row = 0; row < rowsCnt; row++) {
        BootstrappedVector sample = randomVector(false);
        sample.features().set(1, sample.features().get(1) * 100 % 100);

        wholeHist0.addElement(sample);
        wholeHist1.addElement(sample);

        int target = rnd.nextInt(partsCnt);
        parts0.get(target).addElement(sample);
        parts1.get(target).addElement(sample);
    }

    checkSums(wholeHist0, parts0);
    checkSums(wholeHist1, parts1);

    MSEHistogram zero0 = new MSEHistogram(smplId, meta0);
    MSEHistogram zero1 = new MSEHistogram(smplId, meta1);

    assertTrue(wholeHist0.isEqualTo(wholeHist0.plus(zero0)));
    assertTrue(wholeHist1.isEqualTo(wholeHist1.plus(zero1)));
    assertTrue(wholeHist0.isEqualTo(zero0.plus(wholeHist0)));
    assertTrue(wholeHist1.isEqualTo(zero1.plus(wholeHist1)));
}
Usage of org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector in the Apache Ignite project:
class MSEHistogramTest, method testAdd.
/**
 * Fills four MSE histograms from the shared {@code dataset} fixture and checks each one. The four
 * histograms cover sample ids 0 and 1, crossed with {@code feature1Meta} and {@code feature2Meta}.
 * For each histogram the test checks the bucket ids, per-bucket counters, per-bucket label sums and
 * per-bucket squared-label sums.
 */
@Test
public void testAdd() {
    MSEHistogram catSmpl0 = new MSEHistogram(0, feature1Meta);
    MSEHistogram contSmpl0 = new MSEHistogram(0, feature2Meta);
    MSEHistogram catSmpl1 = new MSEHistogram(1, feature1Meta);
    MSEHistogram contSmpl1 = new MSEHistogram(1, feature2Meta);

    for (BootstrappedVector row : dataset) {
        catSmpl0.addElement(row);
        catSmpl1.addElement(row);
        contSmpl0.addElement(row);
        contSmpl1.addElement(row);
    }

    // feature1Meta, sample 0: bucket ids, counters, label sums, squared-label sums.
    checkBucketIds(catSmpl0.buckets(), new Integer[] {0, 1});
    checkCounters(catSmpl0.getCounters(), new double[] {4, 4});
    checkCounters(catSmpl0.getSumOfLabels(), new double[] {2 * 4 + 2 * 3, 5 + 1 + 2 * 2});
    checkCounters(catSmpl0.getSumOfSquaredLabels(),
        new double[] {2 * 4 * 4 + 2 * 3 * 3, 5 * 5 + 1 + 2 * 2 * 2});

    // feature1Meta, sample 1.
    checkBucketIds(catSmpl1.buckets(), new Integer[] {0, 1});
    checkCounters(catSmpl1.getCounters(), new double[] {1, 5});
    checkCounters(catSmpl1.getSumOfLabels(), new double[] {4, 2 * 5 + 2 * 1 + 2});
    checkCounters(catSmpl1.getSumOfSquaredLabels(),
        new double[] {4 * 4, 2 * 5 * 5 + 2 * 1 * 1 + 2 * 2});

    // feature2Meta, sample 0.
    checkBucketIds(contSmpl0.buckets(), new Integer[] {1, 4, 6, 7, 8});
    checkCounters(contSmpl0.getCounters(), new double[] {1, 1, 2, 2, 2});
    checkCounters(contSmpl0.getSumOfLabels(), new double[] {5 * 1, 1 * 1, 4 * 2, 2 * 2, 3 * 2});
    checkCounters(contSmpl0.getSumOfSquaredLabels(),
        new double[] {1 * 5 * 5, 1 * 1 * 1, 2 * 4 * 4, 2 * 2 * 2, 2 * 3 * 3});

    // feature2Meta, sample 1 (bucket 8 is present with a zero counter).
    checkBucketIds(contSmpl1.buckets(), new Integer[] {1, 4, 6, 7, 8});
    checkCounters(contSmpl1.getCounters(), new double[] {2, 2, 1, 1, 0});
    checkCounters(contSmpl1.getSumOfLabels(), new double[] {2 * 5, 2 * 1, 1 * 4, 2 * 1, 0 * 3});
    checkCounters(contSmpl1.getSumOfSquaredLabels(),
        new double[] {2 * 5 * 5, 2 * 1 * 1, 1 * 4 * 4, 1 * 2 * 2, 0 * 3 * 3});
}
Usage of org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector in the Apache Ignite project:
class GiniFeatureHistogramTest, method testJoin.
/**
 * Checks that {@code plus} merges two Gini histograms correctly. One histogram is built from the
 * {@code dataset} fixture and the other from {@code toSplitDataset}. The check is done for
 * {@code feature1Meta} and for {@code feature2Meta}, and covers the merged bucket ids plus each
 * label's per-bucket counters and bucket ids.
 */
@Test
public void testJoin() {
    // Maps each label value to its index.
    Map<Double, Integer> lblMapping = new HashMap<>();
    lblMapping.put(1.0, 0);
    lblMapping.put(2.0, 1);
    lblMapping.put(3.0, 2);

    GiniHistogram catFeatureSmpl1 = new GiniHistogram(0, lblMapping, feature1Meta);
    GiniHistogram catFeatureSmpl2 = new GiniHistogram(0, lblMapping, feature1Meta);
    GiniHistogram contFeatureSmpl1 = new GiniHistogram(0, lblMapping, feature2Meta);
    GiniHistogram contFeatureSmpl2 = new GiniHistogram(0, lblMapping, feature2Meta);

    for (BootstrappedVector vec : dataset) {
        catFeatureSmpl1.addElement(vec);
        contFeatureSmpl1.addElement(vec);
    }

    for (BootstrappedVector vec : toSplitDataset) {
        catFeatureSmpl2.addElement(vec);
        contFeatureSmpl2.addElement(vec);
    }

    GiniHistogram res1 = catFeatureSmpl1.plus(catFeatureSmpl2);
    GiniHistogram res2 = contFeatureSmpl1.plus(contFeatureSmpl2);

    // The merged histograms expose the union of the buckets seen by their operands.
    checkBucketIds(res1.buckets(), new Integer[] {0, 1, 2});
    checkBucketIds(res2.buckets(), new Integer[] {1, 4, 6, 7, 8});

    // feature1Meta (categorical): each per-label histogram only contains the buckets in which
    // that label occurs.
    // Label 1.0: buckets 0, 1 and 2.
    checkCounters(res1.getHistForLabel(1.0), new double[] {3, 2, 6});
    checkBucketIds(res1.getHistForLabel(1.0).buckets(), new Integer[] {0, 1, 2});
    // Label 2.0: buckets 0 and 1.
    checkCounters(res1.getHistForLabel(2.0), new double[] {4, 6});
    checkBucketIds(res1.getHistForLabel(2.0).buckets(), new Integer[] {0, 1});
    // Label 3.0: bucket 0 only.
    checkCounters(res1.getHistForLabel(3.0), new double[] {2});
    checkBucketIds(res1.getHistForLabel(3.0).buckets(), new Integer[] {0});

    // feature2Meta (continuous).
    // Label 1.0: buckets 1, 4, 6 and 8.
    checkCounters(res2.getHistForLabel(1.0), new double[] {1, 1, 8, 1});
    checkBucketIds(res2.getHistForLabel(1.0).buckets(), new Integer[] {1, 4, 6, 8});
    // Label 2.0: buckets 1, 4, 6 and 7; bucket 6 is present with a zero counter.
    checkCounters(res2.getHistForLabel(2.0), new double[] {1, 4, 0, 5});
    checkBucketIds(res2.getHistForLabel(2.0).buckets(), new Integer[] {1, 4, 6, 7});
    // Label 3.0: bucket 8 only.
    checkCounters(res2.getHistForLabel(3.0), new double[] {2});
    checkBucketIds(res2.getHistForLabel(3.0).buckets(), new Integer[] {8});
}
Usage of org.apache.ignite.ml.dataset.impl.bootstrapping.BootstrappedVector in the Apache Ignite project:
class GiniFeatureHistogramTest, method testSplit.
/**
 * Checks best-split search on Gini histograms filled from the {@code toSplitDataset} fixture.
 * The {@code feature1Meta} histogram is expected to split at 1.0. The {@code feature2Meta}
 * histogram, rebucketed to min value -5 and bucket size 1, is expected to split at -0.5. Two
 * {@code feature3Meta} histograms are expected to yield no split: one left empty and one filled.
 */
@Test
public void testSplit() {
    Map<Double, Integer> lblToIdx = new HashMap<>();
    lblToIdx.put(1.0, 0);
    lblToIdx.put(2.0, 1);

    GiniHistogram feature1Hist = new GiniHistogram(0, lblToIdx, feature1Meta);
    GiniHistogram feature2Hist = new GiniHistogram(0, lblToIdx, feature2Meta);
    GiniHistogram unfilledHist = new GiniHistogram(0, lblToIdx, feature3Meta);
    GiniHistogram feature3Hist = new GiniHistogram(0, lblToIdx, feature3Meta);

    // feature2Meta is reconfigured after feature2Hist was created but before any row is added.
    // NOTE(review): this presumably relies on the histogram holding a reference to the same
    // BucketMeta instance; it also mutates a shared fixture — confirm fixtures are recreated per test.
    feature2Meta.setMinVal(-5);
    feature2Meta.setBucketSize(1);

    for (BootstrappedVector row : toSplitDataset) {
        feature1Hist.addElement(row);
        feature2Hist.addElement(row);
        feature3Hist.addElement(row);
    }

    NodeSplit feature1Split = feature1Hist.findBestSplit().get();
    NodeSplit feature2Split = feature2Hist.findBestSplit().get();

    assertEquals(1.0, feature1Split.getVal(), 0.01);
    assertEquals(-0.5, feature2Split.getVal(), 0.01);

    assertFalse(unfilledHist.findBestSplit().isPresent());
    assertFalse(feature3Hist.findBestSplit().isPresent());
}
Aggregations