Usage of org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix in the Apache Ignite project.
Class OLSMultipleLinearRegressionTest, method testWampler2.
/**
 * Fits the NIST Wampler2 data set and checks the expected coefficients and a perfect fit
 * (zero standard errors, zero residuals, R-squared of one).
 * http://www.itl.nist.gov/div898/strd/lls/data/Wampler2.shtml
 */
@Test
public void testWampler2() {
    // Consecutive (y, x) pairs.
    double[] data = new double[] {
        1.00000, 0, 1.11111, 1, 1.24992, 2, 1.42753, 3, 1.65984, 4,
        1.96875, 5, 2.38336, 6, 2.94117, 7, 3.68928, 8, 4.68559, 9,
        6.00000, 10, 7.71561, 11, 9.92992, 12, 12.75603, 13, 16.32384, 14,
        20.78125, 15, 26.29536, 16, 33.05367, 17, 41.26528, 18, 51.16209, 19,
        63.00000, 20
    };
    OLSMultipleLinearRegression mdl = new OLSMultipleLinearRegression();
    final int nvars = 5;
    final int nobs = 21;
    final int rowLen = nvars + 1;

    // Flat sample layout: each row is y followed by x, x^2, ..., x^nvars.
    double[] design = new double[rowLen * nobs];
    for (int obs = 0; obs < nobs; obs++) {
        int row = obs * rowLen;
        int src = obs * 2;
        double x = data[src + 1];

        design[row] = data[src];
        design[row + 1] = x;

        // Each higher power is x times the previous power.
        for (int p = 2; p <= nvars; p++) {
            design[row + p] = x * design[row + p - 1];
        }
    }

    mdl.newSampleData(design, nobs, nvars, new DenseLocalOnHeapMatrix());

    double[] betaHat = mdl.estimateRegressionParameters();
    TestUtils.assertEquals(betaHat, new double[] { 1.0, 1.0e-1, 1.0e-2, 1.0e-3, 1.0e-4, 1.0e-5 }, 1E-8);

    double[] se = mdl.estimateRegressionParametersStandardErrors();
    TestUtils.assertEquals(se, new double[] { 0.0, 0.0, 0.0, 0.0, 0.0, 0.0 }, 1E-8);

    TestUtils.assertEquals(1.0, mdl.calculateRSquared(), 1.0e-10);
    TestUtils.assertEquals(0, mdl.estimateErrorVariance(), 1.0e-7);
    TestUtils.assertEquals(0.00, mdl.calculateResidualSumOfSquares(), 1.0e-6);
}
Usage of org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix in the Apache Ignite project.
Class AbstractMultipleLinearRegressionTest, method testNewSampleInvalidData.
/**
 * Supplies only four values, although two observations of three regressors plus a response
 * need (3 + 1) * 2 = 8. The sample must be rejected with MathIllegalArgumentException.
 */
@Test(expected = MathIllegalArgumentException.class)
public void testNewSampleInvalidData() {
    createRegression().newSampleData(new double[] { 1, 2, 3, 4 }, 2, 3, new DenseLocalOnHeapMatrix());
}
Usage of org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix in the Apache Ignite project.
Class AbstractMultipleLinearRegressionTest, method testNewSampleInsufficientData.
/**
 * Supplies exactly one row of (3 + 1) values, i.e. a single observation for three regressors.
 * The sample must be rejected with MathIllegalArgumentException.
 */
@Test(expected = MathIllegalArgumentException.class)
public void testNewSampleInsufficientData() {
    createRegression().newSampleData(new double[] { 1, 2, 3, 4 }, 1, 3, new DenseLocalOnHeapMatrix());
}
Usage of org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix in the Apache Ignite project.
Class AbstractMultipleLinearRegressionTest, method testNewSampleNullData.
/**
 * Passes a null flat sample and expects NullArgumentException.
 */
@Test(expected = NullArgumentException.class)
public void testNewSampleNullData() {
    // The cast keeps overload resolution on the double[] variant.
    createRegression().newSampleData((double[])null, 2, 3, new DenseLocalOnHeapMatrix());
}
Usage of org.apache.ignite.ml.math.impls.matrix.DenseLocalOnHeapMatrix in the Apache Ignite project.
Class AbstractMultipleLinearRegressionTest, method testNewSample.
/**
 * Verifies that the newSampleData methods consistently insert unitary columns
 * in the design matrix, both with the default intercept and with the intercept
 * disabled. Confirms the fix for MATH-411.
 */
@Test
public void testNewSample() {
    double[] design = new double[] { 1, 19, 22, 33, 2, 20, 30, 40, 3, 25, 35, 45, 4, 27, 37, 47 };
    double[] y = new double[] { 1, 2, 3, 4 };
    double[][] x = new double[][] { { 19, 22, 33 }, { 20, 30, 40 }, { 25, 35, 45 }, { 27, 37, 47 } };
    AbstractMultipleLinearRegression regression = createRegression();

    assertFlatAndSplitSamplesAgree(regression, design, x, y);

    // No intercept.
    regression.setNoIntercept(true);
    assertFlatAndSplitSamplesAgree(regression, design, x, y);
}

/**
 * Loads {@code design} through the flat newSampleData entry point and snapshots X and Y.
 * Then it reloads the same data via newXSampleData/newYSampleData and asserts that both
 * paths produce equal X and Y.
 *
 * @param regression Regression under test (its current intercept setting is kept).
 * @param design Flat sample: each row is y followed by the regressor values.
 * @param x Regressor rows matching {@code design}.
 * @param y Responses matching {@code design}.
 */
private void assertFlatAndSplitSamplesAgree(AbstractMultipleLinearRegression regression, double[] design,
    double[][] x, double[] y) {
    regression.newSampleData(design, x.length, x[0].length, new DenseLocalOnHeapMatrix());
    Matrix flatX = regression.getX().copy();
    Vector flatY = regression.getY().copy();

    regression.newXSampleData(new DenseLocalOnHeapMatrix(x));
    regression.newYSampleData(new DenseLocalOnHeapVector(y));

    Assert.assertEquals(flatX, regression.getX());
    Assert.assertEquals(flatY, regression.getY());
}
Aggregations