use of org.apache.ignite.ml.math.Matrix in project ignite by apache.
the class DiagonalMatrixTest method testSetGet.
/**
 * Verifies element access on diagonal matrices: the shared test fixture, diagonals taken from
 * dense matrices with one extra row or one extra column, and diagonal matrices constructed from
 * a vector and from a raw array of diagonal values.
 */
@Test
public void testSetGet() {
    verifyDiagonal(testMatrix);

    final int n = MathTestConstants.STORAGE_SIZE;

    // Non-square dense sources: (n + 1) x n and n x (n + 1).
    final Matrix[] sources = {new DenseLocalOnHeapMatrix(n + 1, n), new DenseLocalOnHeapMatrix(n, n + 1)};
    for (Matrix src : sources) {
        fillMatrix(src);
        verifyDiagonal(new DiagonalMatrix(src));
    }

    // Diagonal values 1, 2, ..., n.
    final double[] diag = new double[n];
    for (int idx = 0; idx < n; idx++)
        diag[idx] = idx + 1;

    // Diagonal matrix constructed from a vector holding the diagonal values.
    final Matrix fromVector = new DiagonalMatrix(new DenseLocalOnHeapVector(diag));
    assertEquals("Rows in matrix constructed from vector", n, fromVector.rowSize());
    assertEquals("Cols in matrix constructed from vector", n, fromVector.columnSize());
    for (int idx = 0; idx < n; idx++)
        assertEquals(UNEXPECTED_VALUE + " at vector index " + idx, diag[idx], fromVector.get(idx, idx), 0d);
    verifyDiagonal(fromVector);

    // Diagonal matrix constructed directly from the same array.
    final Matrix fromArray = new DiagonalMatrix(diag);
    assertEquals("Rows in matrix constructed from array", n, fromArray.rowSize());
    assertEquals("Cols in matrix constructed from array", n, fromArray.columnSize());
    for (int idx = 0; idx < n; idx++)
        assertEquals(UNEXPECTED_VALUE + " at array index " + idx, diag[idx], fromArray.get(idx, idx), 0d);
    verifyDiagonal(fromArray);
}
use of org.apache.ignite.ml.math.Matrix in project ignite by apache.
the class FunctionMatrixConstructorTest method basicTest(int rows, int cols).
/**
 * Checks a {@link FunctionMatrix} of the given shape in two flavours: read-only (getter only),
 * where every {@code set} must throw {@link UnsupportedOperationException}, and read-write
 * (getter plus setter), whose writes must become visible through {@code get}.
 *
 * @param rows Number of rows.
 * @param cols Number of columns.
 */
private void basicTest(int rows, int cols) {
    // row * cols + col gives every cell a distinct value, so a swapped or misrouted
    // column index is detected (row * cols + row would repeat one value across a row).
    double[][] data = new double[rows][cols];
    for (int row = 0; row < rows; row++)
        for (int col = 0; col < cols; col++)
            data[row][col] = row * cols + col;

    // Read-only flavour: no setter function supplied.
    Matrix mReadOnly = new FunctionMatrix(rows, cols, (i, j) -> data[i][j]);
    assertEquals("Rows in matrix.", rows, mReadOnly.rowSize());
    assertEquals("Cols in matrix.", cols, mReadOnly.columnSize());

    for (int row = 0; row < rows; row++)
        for (int col = 0; col < cols; col++) {
            assertEquals("Unexpected value at " + row + "x" + col, data[row][col], mReadOnly.get(row, col), 0d);

            boolean expECaught = false;
            try {
                mReadOnly.set(row, col, 0.0);
            }
            catch (UnsupportedOperationException uoe) {
                expECaught = true;
            }
            assertTrue("Expected exception wasn't thrown at " + row + "x" + col, expECaught);
        }

    // Read-write flavour: the setter writes straight into the backing array.
    Matrix m = new FunctionMatrix(rows, cols, (i, j) -> data[i][j], (i, j, val) -> data[i][j] = val);
    assertEquals("Rows in matrix, with setter function.", rows, m.rowSize());
    assertEquals("Cols in matrix, with setter function.", cols, m.columnSize());

    for (int row = 0; row < rows; row++)
        for (int col = 0; col < cols; col++) {
            assertEquals("Unexpected value at " + row + "x" + col, data[row][col], m.get(row, col), 0d);
            m.set(row, col, -data[row][col]);
        }

    // Expected values are recomputed from indices rather than read back from data,
    // because the setter above has overwritten data.
    for (int row = 0; row < rows; row++)
        for (int col = 0; col < cols; col++)
            assertEquals("Unexpected value set at " + row + "x" + col, -(row * cols + col), m.get(row, col), 0d);

    assertTrue("Incorrect copy of matrix.", m.copy().equals(m));
}
use of org.apache.ignite.ml.math.Matrix in project ignite by apache.
the class FunctionMatrixConstructorTest method basicTest().
/**
 * Runs the shape-parameterized check for every combination of 1..3 rows and 1..3 columns,
 * then checks {@code equals} of a 1x1 function matrix against itself and against {@code null}.
 */
@Test
public void basicTest() {
    final int[] dims = {1, 2, 3};
    for (int rows : dims) {
        for (int cols : dims)
            basicTest(rows, cols);
    }

    Matrix single = new FunctionMatrix(1, 1, (r, c) -> 1d);

    // noinspection EqualsWithItself
    assertTrue("Matrix is expected to be equal to self.", single.equals(single));

    // noinspection ObjectEqualsNull
    assertFalse("Matrix is expected to be not equal to null.", single.equals(null));
}
use of org.apache.ignite.ml.math.Matrix in project ignite by apache.
the class MatrixImplementationFixtures method consumeSampleMatrix.
/**
 * Feeds every sample matrix produced by every registered fixture supplier to {@code consumer},
 * together with the fixture's current description, and destroys each matrix afterwards.
 *
 * @param consumer Receives a sample matrix and a description of the fixture it came from.
 */
void consumeSampleMatrix(BiConsumer<Matrix, String> consumer) {
    for (Supplier<Iterable<Matrix>> fixtureSupplier : suppliers) {
        final Iterable<Matrix> fixture = fixtureSupplier.get();

        for (Matrix matrix : fixture) {
            try {
                // Description is taken per sample: a fixture's toString may reflect its
                // current iteration position, so it is not hoisted out of the loop.
                consumer.accept(matrix, fixture.toString());
            }
            finally {
                // Release the sample even when the consumer fails (e.g. an assertion error),
                // so whatever storage the matrix holds is not leaked by failing tests.
                matrix.destroy();
            }
        }
    }
}
use of org.apache.ignite.ml.math.Matrix in project ignite by apache.
the class MatrixImplementationsTest method testDeterminant.
/**
 * Checks {@code determinant()} on square sample matrices. 1x1 and 2x2 results are compared with
 * their closed-form values; for larger matrices (up to 512 rows) a diagonal matrix is built from
 * the sample's diagonal and its determinant is compared with the product of the diagonal entries.
 * Non-square samples, ignored implementations and samples with more than 512 rows are skipped.
 */
@Test
public void testDeterminant() {
    consumeSampleMatrix((m, desc) -> {
        if (m.rowSize() != m.columnSize())
            return;

        if (ignore(m.getClass()))
            return;

        // Presumably holds the values written into m by the fill helper — used as reference.
        double[][] doubles = fillIntAndReturn(m);

        // assertEquals takes (message, expected, actual, delta): the reference value goes first
        // so failure messages report expected/actual correctly.
        if (m.rowSize() == 1) {
            assertEquals("Unexpected value " + desc, doubles[0][0], m.determinant(), 0d);
            return;
        }

        if (m.rowSize() == 2) {
            double det = doubles[0][0] * doubles[1][1] - doubles[0][1] * doubles[1][0];
            assertEquals("Unexpected value " + desc, det, m.determinant(), 0d);
            return;
        }

        if (m.rowSize() > 512) {
            // IMPL NOTE if row size >= 30000 it takes unacceptably long for normal test run.
            return;
        }

        // Diagonal copy of m: its determinant is simply the product of its diagonal entries.
        Matrix diagMtx = m.like(m.rowSize(), m.columnSize());
        diagMtx.assign(0);
        for (int i = 0; i < m.rowSize(); i++)
            diagMtx.set(i, i, m.get(i, i));

        double det = 1;
        for (int i = 0; i < diagMtx.rowSize(); i++)
            det *= diagMtx.get(i, i);

        try {
            assertEquals("Unexpected value " + desc, det, diagMtx.determinant(), DEFAULT_DELTA);
        }
        catch (Exception e) {
            // AssertionError is not an Exception, so this only reports failures thrown by
            // determinant() itself; assertion failures already carry desc in their message.
            System.out.println(desc);
            throw e;
        }
    });
}
Aggregations