use of org.apache.ignite.ml.math.Matrix in project ignite by apache.
the class MatrixImplementationsTest method testDivide.
/**
 * Verifies scalar division on every sample matrix: each cell of {@code m.divide(divisor)} must be
 * exactly equal to the matching entry of the array returned by {@code fillAndReturn(m)} divided
 * by the same divisor. Matrix classes for which {@code ignore(...)} returns {@code true} are skipped.
 */
@Test
public void testDivide() {
    consumeSampleMatrix((m, desc) -> {
        if (!ignore(m.getClass())) {
            final double[][] src = fillAndReturn(m);

            // The divisor is drawn from Math.random() on every invocation.
            final double divisor = Math.random();

            final Matrix quotient = m.divide(divisor);

            for (int r = 0; r < m.rowSize(); r++) {
                for (int c = 0; c < m.columnSize(); c++) {
                    final String msg = "Unexpected value for " + desc + " at (" + r + "," + c + ")";

                    // Zero tolerance: the result has to match the plain double division bit for bit.
                    assertEquals(msg, src[r][c] / divisor, quotient.get(r, c), 0d);
                }
            }
        }
    });
}
use of org.apache.ignite.ml.math.Matrix in project ignite by apache.
the class MatrixImplementationsTest method testPlusMatrix.
/**
 * Verifies matrix addition by adding each sample matrix to itself: every cell of {@code m.plus(m)}
 * must be exactly twice the matching entry of the array returned by {@code fillAndReturn(m)}.
 * Matrix classes for which {@code ignore(...)} returns {@code true} are skipped.
 */
@Test
public void testPlusMatrix() {
    consumeSampleMatrix((m, desc) -> {
        if (!ignore(m.getClass())) {
            final double[][] src = fillAndReturn(m);

            final Matrix sum = m.plus(m);

            for (int r = 0; r < m.rowSize(); r++) {
                for (int c = 0; c < m.columnSize(); c++) {
                    final String msg = "Unexpected value for " + desc + " at (" + r + "," + c + ")";

                    // Zero tolerance: doubling is expected to be reproduced exactly.
                    assertEquals(msg, src[r][c] * 2.0, sum.get(r, c), 0d);
                }
            }
        }
    });
}
use of org.apache.ignite.ml.math.Matrix in project ignite by apache.
the class MatrixImplementationsTest method testInverse.
/**
 * Verifies matrix inversion: for each eligible sample matrix the product {@code m.times(m.inverse())}
 * is spot-checked against the identity matrix (determinant, corner cells and the central diagonal
 * cell) with a tolerance of {@code 0.001}.
 * <p>
 * A sample is skipped when it is not square, when {@code ignore(...)} returns {@code true} for its
 * class, or when it has more than 256 rows.
 */
@Test
public void testInverse() {
    consumeSampleMatrix((m, desc) -> {
        // IMPL NOTE the 256-row limit is for quicker test run.
        if (m.rowSize() != m.columnSize() || ignore(m.getClass()) || m.rowSize() > 256) {
            return;
        }

        fillNonSingularMatrix(m);

        // The inverse only exists for a non-zero determinant, so check that before inverting.
        assertTrue("Unexpected zero determinant " + desc, Math.abs(m.determinant()) > 0d);

        final Matrix inv = m.inverse();
        final Matrix product = m.times(inv);

        final double delta = 0.001d;

        assertEquals("Unexpected determinant " + desc, 1d, product.determinant(), delta);
        assertEquals("Unexpected top left value " + desc, 1d, product.get(0, 0), delta);

        final int size = m.rowSize();

        // A 1x1 matrix has no further cells to inspect.
        if (size == 1) {
            return;
        }

        final int mid = size / 2;
        final int last = size - 1;

        // Diagonal cells of the product are expected to be 1, off-diagonal corners 0.
        assertEquals("Unexpected center value " + desc, 1d, product.get(mid, mid), delta);
        assertEquals("Unexpected bottom right value " + desc, 1d, product.get(last, last), delta);
        assertEquals("Unexpected top right value " + desc, 0d, product.get(0, last), delta);
        assertEquals("Unexpected bottom left value " + desc, 0d, product.get(last, 0), delta);
    });
}
use of org.apache.ignite.ml.math.Matrix in project ignite by apache.
the class MatrixImplementationsTest method testSwapColumns.
/**
 * Verifies {@code swapColumns}: after swapping column 0 with column 1 (or with itself when the
 * matrix has a single column), each of the two columns of the returned matrix must hold exactly
 * the values that the array returned by {@code fillAndReturn(m)} had in the other column.
 * Afterwards negative and too large column indices are passed through {@code testInvalidColIndex}.
 * <p>
 * Matrices for which {@code readOnly(m)} returns {@code true} are skipped.
 */
@Test
public void testSwapColumns() {
    consumeSampleMatrix((m, desc) -> {
        if (readOnly(m)) {
            return;
        }

        final double[][] original = fillAndReturn(m);

        // With a single column the only valid swap is column 0 with itself.
        final int swapI = m.columnSize() == 1 ? 0 : 1;
        final int swapJ = 0;

        final Matrix swapped = m.swapColumns(swapI, swapJ);

        for (int row = 0; row < m.rowSize(); row++) {
            // Expected value (reference data) goes first and the value under test second,
            // so that a failure report labels the two numbers correctly.
            assertEquals("Unexpected value for " + desc + " at row " + row + ", swap_i " + swapI,
                original[row][swapJ], swapped.get(row, swapI), 0d);

            assertEquals("Unexpected value for " + desc + " at row " + row + ", swap_j " + swapJ,
                original[row][swapI], swapped.get(row, swapJ), 0d);
        }

        testInvalidColIndex(() -> m.swapColumns(-1, 0), desc + " negative first swap index");
        testInvalidColIndex(() -> m.swapColumns(0, -1), desc + " negative second swap index");
        testInvalidColIndex(() -> m.swapColumns(m.columnSize(), 0), desc + " too large first swap index");
        testInvalidColIndex(() -> m.swapColumns(0, m.columnSize()), desc + " too large second swap index");
    });
}
use of org.apache.ignite.ml.math.Matrix in project ignite by apache.
the class MatrixViewConstructorTest method basicTest.
/**
 * Builds a {@code MatrixView} over a window of {@code parent} and checks that the view reports the
 * requested size, reads the parent's cells at the given offsets, and writes through to the parent.
 *
 * @param parent Matrix the view is created over; all of its cells are overwritten by this method.
 * @param rowOff Row offset of the view inside the parent.
 * @param colOff Column offset of the view inside the parent.
 * @param rows Number of rows in the view.
 * @param cols Number of columns in the view.
 */
private void basicTest(Matrix parent, int rowOff, int colOff, int rows, int cols) {
    // Give every parent cell a distinct value: 1, 2, 3, ... in row-major order.
    for (int r = 0; r < parent.rowSize(); r++) {
        for (int c = 0; c < parent.columnSize(); c++) {
            parent.set(r, c, r * parent.columnSize() + c + 1);
        }
    }

    final Matrix view = new MatrixView(parent, rowOff, colOff, rows, cols);

    assertEquals("Rows in view.", rows, view.rowSize());
    assertEquals("Cols in view.", cols, view.columnSize());

    // Reading through the view must return the parent's value at the shifted position.
    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < cols; c++) {
            assertEquals("Unexpected value at " + r + "x" + c,
                parent.get(r + rowOff, c + colOff), view.get(r, c), 0d);
        }
    }

    // Zero out the whole window through the view...
    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < cols; c++) {
            view.set(r, c, 0d);
        }
    }

    // ...and confirm that the parent observes the zeros at the shifted positions.
    for (int r = 0; r < rows; r++) {
        for (int c = 0; c < cols; c++) {
            assertEquals("Unexpected value set at " + r + "x" + c,
                0d, parent.get(r + rowOff, c + colOff), 0d);
        }
    }
}
Aggregations