Use of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class OLSMultipleLinearRegressionTest, method cannotAddSampleDataWithSizeMismatch.
/**
 * Checks that {@code newSampleData} refuses sample data whose response vector holds two values
 * while the regressor matrix holds only a single row.
 */
@Test(expected = MathIllegalArgumentException.class)
public void cannotAddSampleDataWithSizeMismatch() {
    // Two observations of the dependent variable...
    double[] responses = {1.0, 2.0};
    // ...but regressors for only one observation.
    double[][] regressors = {{1.0, 0.0}};

    createRegression().newSampleData(
        new DenseLocalOnHeapVector(responses),
        new DenseLocalOnHeapMatrix(regressors));
}
Use of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class DatasetWithObviousStructure, method getPoint.
/**
 * Creates a 2-dimensional point that is a copy of {@code centers.get(cnt)} with every coordinate
 * shifted by {@code rnd.nextDouble() * squareSideLen / 100}, i.e. by up to 1% of the square side.
 *
 * @param cnt Key of the center to sample around.
 * @return Newly created perturbed point.
 */
private Vector getPoint(Integer cnt) {
    Vector blank = new DenseLocalOnHeapVector(2);
    Vector point = blank.assign(centers.get(cnt));

    // Perturb the point by a random offset. The result of map() is intentionally not used:
    // this relies on map() updating the vector in place (presumably; see the Vector API).
    point.map(coord -> coord + rnd.nextDouble() * squareSideLen / 100);

    return point;
}
Use of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class ColumnDecisionTreeTrainerTest, method testCacheContGini.
/**
 * Tests {@link ColumnDecisionTreeTrainer} on purely continuous data (no categorical features)
 * using the Gini split calculator, Gini impurity for categorical features and the mean as the
 * region value.
 */
public void testCacheContGini() {
    IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

    final int ptsCnt = 1 << 10;
    final int featuresCnt = 12;

    // Empty map: no feature is declared categorical.
    HashMap<Integer, Integer> categoricalFeatures = new HashMap<>();

    // Fixed seed keeps the generated data set reproducible.
    Random random = new Random(12349L);

    SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>(
        featuresCnt, categoricalFeatures, () -> new DenseLocalOnHeapVector(featuresCnt + 1), random)
        .split(0, 0, -10.0)
        .split(1, 0, 0.0)
        .split(1, 1, 2.0)
        .split(3, 7, 50.0);

    testByGen(ptsCnt, categoricalFeatures, gen,
        ContinuousSplitCalculators.GINI.apply(ignite),
        RegionCalculators.GINI,
        RegionCalculators.MEAN,
        random);
}
Use of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class ColumnDecisionTreeTrainerTest, method testByGen.
/**
 * Generates {@code totalPts} points with {@code gen}, trains a {@link ColumnDecisionTreeTrainer}
 * on them and checks that, for the first point of every generated region, the model prediction
 * equals the point's label.
 *
 * @param totalPts Number of points to generate.
 * @param catsInfo Categorical features description passed to the trainer input
 *     (presumably feature index -> number of categories).
 * @param gen Generator of (region, vector) pairs; each vector has {@code featuresCnt() + 1}
 *     components (presumably features followed by the label).
 * @param calc Factory of the continuous split calculator.
 * @param catImpCalc Factory of the impurity calculator used for categorical features (presumably).
 * @param regCalc Function computing a region value from a stream of labels.
 * @param rnd Random source used to shuffle the generated points.
 * @param <D> Continuous region info type used by {@code calc}.
 * @throws AssertionError If a prediction differs from the expected label.
 */
private <D extends ContinuousRegionInfo> void testByGen(int totalPts, HashMap<Integer, Integer> catsInfo,
    SplitDataGenerator<DenseLocalOnHeapVector> gen,
    IgniteFunction<ColumnDecisionTreeTrainerInput, ? extends ContinuousSplitCalculator<D>> calc,
    IgniteFunction<ColumnDecisionTreeTrainerInput, IgniteFunction<DoubleStream, Double>> catImpCalc,
    IgniteFunction<DoubleStream, Double> regCalc, Random rnd) {
    List<IgniteBiTuple<Integer, DenseLocalOnHeapVector>> lst = gen.points(totalPts, (i, rn) -> i)
        .collect(Collectors.toList());

    int featCnt = gen.featuresCnt();

    // Shuffle so that points of the same region are not stored in consecutive rows.
    Collections.shuffle(lst, rnd);

    SparseDistributedMatrix m = new SparseDistributedMatrix(totalPts, featCnt + 1,
        StorageConstants.COLUMN_STORAGE_MODE, StorageConstants.RANDOM_ACCESS_MODE);

    // Points grouped by the region they were generated in.
    Map<Integer, List<LabeledVectorDouble>> byRegion = new HashMap<>();

    int i = 0;

    for (IgniteBiTuple<Integer, DenseLocalOnHeapVector> bt : lst) {
        double[] row = bt.get2().getStorage().data();

        byRegion.computeIfAbsent(bt.get1(), region -> new LinkedList<>()).add(asLabeledVector(row));

        m.setRow(i, row);

        i++;
    }

    ColumnDecisionTreeTrainer<D> trainer = new ColumnDecisionTreeTrainer<>(3, calc, catImpCalc, regCalc, ignite);

    DecisionTreeModel mdl = trainer.train(new MatrixColumnDecisionTreeTrainerInput(m, catsInfo));

    // Check one representative point of every region.
    for (Map.Entry<Integer, List<LabeledVectorDouble>> e : byRegion.entrySet()) {
        LabeledVectorDouble sp = e.getValue().get(0);

        Tracer.showAscii(sp.features());

        double pred = mdl.apply(sp.features());

        X.println("Actual and predicted vectors [act=" + sp.label() + ", pred=" + pred + "]");

        // Explicit check: the 'assert' keyword is a no-op unless the JVM runs with -ea,
        // which would let a wrong prediction pass silently.
        if (pred != sp.doubleLabel()) {
            throw new AssertionError("Wrong prediction for region " + e.getKey() +
                " [expected=" + sp.doubleLabel() + ", actual=" + pred + ']');
        }
    }
}
Use of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class ColumnDecisionTreeTrainerTest, method testCacheCat.
/**
 * Tests {@link ColumnDecisionTreeTrainer} on data containing a categorical feature (feature 5),
 * using the variance split calculator, variance impurity for categorical features and the mean
 * as the region value.
 */
public void testCacheCat() {
    IgniteUtils.setCurrentIgniteName(ignite.configuration().getIgniteInstanceName());

    final int ptsCnt = 1 << 10;
    final int featuresCnt = 12;

    // Feature 5 is declared categorical (presumably with 7 categories).
    HashMap<Integer, Integer> categoricalFeatures = new HashMap<>();
    categoricalFeatures.put(5, 7);

    // Fixed seed keeps the generated data set reproducible.
    Random random = new Random(12349L);

    int[] splitCats = {0, 2, 5};

    SplitDataGenerator<DenseLocalOnHeapVector> gen = new SplitDataGenerator<>(
        featuresCnt, categoricalFeatures, () -> new DenseLocalOnHeapVector(featuresCnt + 1), random)
        .split(0, 5, splitCats);

    testByGen(ptsCnt, categoricalFeatures, gen,
        ContinuousSplitCalculators.VARIANCE,
        RegionCalculators.VARIANCE,
        RegionCalculators.MEAN,
        random);
}
Aggregations