Use of org.apache.ignite.ml.math.Matrix in the Apache Ignite project.
Class CholeskyDecompositionTest, method basicTest.
/**
 * Runs the common {@link CholeskyDecomposition} checks against {@code m}.
 *
 * <p>This decomposition is useful for systems of linear equations {@code m x = b} where {@code m} is a
 * Hermitian matrix; for such systems it is a more efficient solving method than LU decomposition.
 * Once the decomposition is computed, several right-hand sides can be solved at once by passing them
 * as the columns of one matrix {@code (b1, b2, ..., bk)}; {@code solve} then returns the solutions as
 * the columns {@code (sol1, sol2, ..., solk)}.
 *
 * <p>Checks performed: the determinant value, that {@code getLT()} is the transpose of {@code getL()},
 * the shape of the matrix solution, and that both {@code solve} overloads return values {@code x}
 * for which {@code m x} reproduces the right-hand side within a small tolerance.
 *
 * @param m Matrix to decompose. The hard-coded expectations below assume it has 3 rows and 3 columns
 *      (the right-hand sides have 3 entries) and a determinant of exactly 4.
 */
private void basicTest(Matrix m) {
    // Independent copy of the input, so the residual checks below do not depend on whether the
    // decomposition or solve touch the storage of m.
    Matrix mCp = m.copy();

    CholeskyDecomposition dec = new CholeskyDecomposition(m);

    assertEquals("Unexpected value for decomposition determinant.",
        4d, dec.getDeterminant(), 0d);

    Matrix l = dec.getL();
    Matrix lt = dec.getLT();

    assertNotNull("Matrix l is expected to be not null.", l);
    assertNotNull("Matrix lt is expected to be not null.", lt);

    for (int row = 0; row < l.rowSize(); row++)
        for (int col = 0; col < l.columnSize(); col++)
            assertEquals("Unexpected value transposed matrix at (" + row + "," + col + ").",
                l.get(row, col), lt.get(col, row), 0d);

    // Solutions are computed in floating point, so residuals are compared with a tolerance.
    final double delta = 0.0001;

    // Right-hand sides written as rows; the matrix is transposed so each one becomes a column.
    // The rows are cloned so these literals stay intact for the residual check.
    final double[][] rhs = {{4.0, -6.0, 7.0}, {1.0, 1.0, 1.0}};

    Matrix bs = new DenseLocalOnHeapMatrix(new double[][] {rhs[0].clone(), rhs[1].clone()}).transpose();

    Matrix sol = dec.solve(bs);

    assertNotNull("Solution matrix is expected to be not null.", sol);
    assertEquals("Solution rows are not as expected.", bs.rowSize(), sol.rowSize());
    assertEquals("Solution columns are not as expected.", bs.columnSize(), sol.columnSize());

    for (int i = 0; i < sol.columnSize(); i++)
        assertNotNull("Solution matrix column is expected to be not null at index " + i, sol.viewColumn(i));

    // Column j of m * sol must reproduce right-hand side j.
    Matrix recomposed = mCp.times(sol);

    for (int row = 0; row < recomposed.rowSize(); row++)
        for (int col = 0; col < recomposed.columnSize(); col++)
            assertEquals("Unexpected residual of solution matrix at (" + row + "," + col + ").",
                rhs[col][row], recomposed.get(row, col), delta);

    final double[] rhsVec = {4.0, -6.0, 7.0};

    Vector b = new DenseLocalOnHeapVector(rhsVec.clone());
    Vector solVec = dec.solve(b);

    assertNotNull("Solution vector is expected to be not null.", solVec);
    assertEquals("Solution vector size is not as expected.", rhsVec.length, solVec.size());

    // Check m * solVec against the right-hand side. Comparing solVec with b directly would not show
    // that solVec solves the system.
    for (int row = 0; row < mCp.rowSize(); row++) {
        double lhs = 0d;

        for (int col = 0; col < mCp.columnSize(); col++)
            lhs += mCp.get(row, col) * solVec.get(col);

        assertEquals("Unexpected residual of solution vector at " + row, rhsVec[row], lhs, delta);
    }

    dec.destroy();
}
Use of org.apache.ignite.ml.math.Matrix in the Apache Ignite project.
Class QRDecompositionTest, method basicTest.
/**
 * Runs the common {@link QRDecomposition} checks against {@code m}, then checks that two fixed
 * non-square matrices are reported as having full rank.
 *
 * <p>Checks performed on {@code m}:
 * <ul>
 *     <li>the decomposition reports full rank;</li>
 *     <li>{@code Q} times its transpose is the identity;</li>
 *     <li>{@code R} has only zeros strictly below its main diagonal;</li>
 *     <li>{@code Q R} reproduces {@code m};</li>
 *     <li>solving against an all-zero 3x10 right-hand side yields an all-zero 3x10 solution.</li>
 * </ul>
 *
 * @param m Matrix to decompose. NOTE(review): the {@code Q Q^T == I} check only holds when {@code Q} is
 *      square, and the hard-coded 3x10 right-hand side assumes {@code m} has 3 rows (and 3 columns for
 *      the expected solution shape). Confirm against callers.
 */
private void basicTest(Matrix m) {
    QRDecomposition dec = new QRDecomposition(m);

    assertTrue("Unexpected value for full rank in decomposition " + dec, dec.hasFullRank());

    Matrix q = dec.getQ();
    Matrix r = dec.getR();

    assertNotNull("Matrix q is expected to be not null.", q);
    assertNotNull("Matrix r is expected to be not null.", r);

    Matrix qSafeCp = safeCopy(q);

    Matrix expIdentity = qSafeCp.times(qSafeCp.transpose());

    final double delta = 0.0001;

    // Index as (row, col) so the failure message names the entry actually read.
    for (int row = 0; row < expIdentity.rowSize(); row++)
        for (int col = 0; col < expIdentity.columnSize(); col++)
            assertEquals("Unexpected identity matrix value at (" + row + "," + col + ").",
                row == col ? 1d : 0d, expIdentity.get(row, col), delta);

    // Every entry strictly below the main diagonal (col < row) must be zero, including the first
    // sub-diagonal. The bound is clamped to the column count in case R has more rows than columns.
    for (int row = 0; row < r.rowSize(); row++)
        for (int col = 0; col < Math.min(row, r.columnSize()); col++)
            assertEquals("Unexpected upper triangular matrix value at (" + row + "," + col + ").",
                0d, r.get(row, col), delta);

    Matrix recomposed = qSafeCp.times(r);

    for (int row = 0; row < m.rowSize(); row++)
        for (int col = 0; col < m.columnSize(); col++)
            assertEquals("Unexpected recomposed matrix value at (" + row + "," + col + ").",
                m.get(row, col), recomposed.get(row, col), delta);

    // A zero right-hand side must produce a zero solution.
    Matrix sol = dec.solve(new DenseLocalOnHeapMatrix(3, 10));

    assertEquals("Unexpected rows in solution matrix.", 3, sol.rowSize());
    assertEquals("Unexpected cols in solution matrix.", 10, sol.columnSize());

    for (int row = 0; row < sol.rowSize(); row++)
        for (int col = 0; col < sol.columnSize(); col++)
            assertEquals("Unexpected solution matrix value at (" + row + "," + col + ").",
                0d, sol.get(row, col), delta);

    dec.destroy();

    QRDecomposition dec1 = new QRDecomposition(new DenseLocalOnHeapMatrix(new double[][] {
        {2.0d, -1.0d},
        {-1.0d, 2.0d},
        {0.0d, -1.0d}
    }));

    assertTrue("Unexpected value for full rank in decomposition " + dec1, dec1.hasFullRank());

    dec1.destroy();

    QRDecomposition dec2 = new QRDecomposition(new DenseLocalOnHeapMatrix(new double[][] {
        {2.0d, -1.0d, 0.0d, 0.0d},
        {-1.0d, 2.0d, -1.0d, 0.0d},
        {0.0d, -1.0d, 2.0d, 0.0d}
    }));

    assertTrue("Unexpected value for full rank in decomposition " + dec2, dec2.hasFullRank());

    dec2.destroy();
}
Use of org.apache.ignite.ml.math.Matrix in the Apache Ignite project.
Class DiagonalMatrixTest, method testConstant.
/**
 * Builds a {@link DiagonalMatrix} of size {@code MathTestConstants.STORAGE_SIZE} from a single constant
 * (tried with -1, 0 and 1). Checks that both dimensions equal that size and that every diagonal
 * position holds the constant exactly, then passes the matrix to {@code verifyDiagonal}.
 */
@Test
public void testConstant() {
    final int n = MathTestConstants.STORAGE_SIZE;
    final double[] constants = {-1.0, 0.0, 1.0};

    for (int k = 0; k < constants.length; k++) {
        final double c = constants[k];

        Matrix diag = new DiagonalMatrix(n, c);

        assertEquals("Rows in matrix", n, diag.rowSize());
        assertEquals("Cols in matrix", n, diag.columnSize());

        int idx = 0;

        while (idx < n) {
            assertEquals(UNEXPECTED_VALUE + " at index " + idx, c, diag.get(idx, idx), 0d);

            idx++;
        }

        verifyDiagonal(diag, true);
    }
}
Use of org.apache.ignite.ml.math.Matrix in the Apache Ignite project.
Class MatrixImplementationsTest, method testViewPart.
/**
 * For every sample matrix not excluded by {@code ignore}, fills it and checks that both {@code viewPart}
 * overloads agree element-wise with the corresponding region of the source matrix.
 *
 * <p>The overloads differ in how the region is given: scalar offsets and sizes, or {@code int[]}
 * offset and size pairs. In a dimension of at least 3, the region skips one row/column at each end.
 * Otherwise it is a single row/column starting at offset 0.
 */
@Test
public void testViewPart() {
    consumeSampleMatrix((m, desc) -> {
        if (ignore(m.getClass()))
            return;

        fillMatrix(m);

        final int rowOff;
        final int rows;

        if (m.rowSize() < 3) {
            rowOff = 0;
            rows = 1;
        }
        else {
            rowOff = 1;
            rows = m.rowSize() - 2;
        }

        final int colOff;
        final int cols;

        if (m.columnSize() < 3) {
            colOff = 0;
            cols = 1;
        }
        else {
            colOff = 1;
            cols = m.columnSize() - 2;
        }

        Matrix viewByScalars = m.viewPart(rowOff, rows, colOff, cols);
        Matrix viewByArrays = m.viewPart(new int[] {rowOff, colOff}, new int[] {rows, cols});

        String details = desc + " view [" + rowOff + ", " + rows + ", " + colOff + ", " + cols + "]";

        for (int r = 0; r < rows; r++) {
            for (int c = 0; c < cols; c++) {
                String at = " at (" + r + "," + c + ")";

                assertEquals("Unexpected view1 value for " + details + at,
                    m.get(r + rowOff, c + colOff), viewByScalars.get(r, c), 0d);

                assertEquals("Unexpected view2 value for " + details + at,
                    m.get(r + rowOff, c + colOff), viewByArrays.get(r, c), 0d);
            }
        }
    });
}
Use of org.apache.ignite.ml.math.Matrix in the Apache Ignite project.
Class MatrixImplementationsTest, method testMapMatrix.
/**
 * For every sample matrix not excluded by {@code ignore}: fills it (keeping the array returned by
 * {@code fillAndReturn}), runs {@code testMapMatrixWrongCardinality}, then maps the matrix against a
 * copy of itself with {@code (m1, m2) -> m1 + m2}. Finally it checks that every element of the mapped
 * matrix equals twice the corresponding value from {@code fillAndReturn}.
 *
 * <p>NOTE(review): the final check presumes {@code fillAndReturn} returns the values it wrote into the
 * matrix and that {@code map} stores its result in {@code m}. Both are implied by the assertion, not
 * shown here.
 */
@Test
public void testMapMatrix() {
    consumeSampleMatrix((m, desc) -> {
        if (ignore(m.getClass()))
            return;

        double[][] doubles = fillAndReturn(m);

        testMapMatrixWrongCardinality(m, desc);

        Matrix cp = m.copy();

        m.map(cp, (m1, m2) -> m1 + m2);

        // JUnit's signature is assertEquals(message, expected, actual, delta). The expected value
        // (doubled original) goes first so failure output reports expected and actual correctly.
        for (int i = 0; i < m.rowSize(); i++)
            for (int j = 0; j < m.columnSize(); j++)
                assertEquals("Unexpected value for " + desc + " at (" + i + "," + j + ")",
                    doubles[i][j] * 2, m.get(i, j), 0d);
    });
}
Aggregations