Usage of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class BlasTest, method testGemvDenseDenseDense.
/**
 * Verifies the BLAS 'gemv' update {@code y := alpha * A * x + beta * y} when the matrix A,
 * the vector x and the vector y all use dense on-heap storage.
 */
@Test
public void testGemvDenseDenseDense() {
    // Scalars of the update.
    double beta = 2.0;
    double alpha = 3.0;

    // Operands: a 2x2 dense matrix and two dense 2-element vectors.
    DenseLocalOnHeapMatrix a = new DenseLocalOnHeapMatrix(new double[][] { { 10.0, 11.0 }, { 0.0, 1.0 } }, 2);
    DenseLocalOnHeapVector x = new DenseLocalOnHeapVector(new double[] { 1.0, 2.0 });
    DenseLocalOnHeapVector y = new DenseLocalOnHeapVector(new double[] { 3.0, 4.0 });

    // Reference result from ordinary matrix/vector algebra, computed before gemv runs:
    // the assertion below checks y itself afterwards, i.e. gemv is expected to update y in place.
    DenseLocalOnHeapVector expected = (DenseLocalOnHeapVector) y.times(beta).plus(a.times(x).times(alpha));

    Blas.gemv(alpha, a, x, beta, y);

    Assert.assertEquals(expected, y);
}
Usage of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class BlasTest, method testDot.
/**
 * Tests the 'dot' operation.
 *
 * <p>{@code Blas.dot} must match both the known product of the inputs
 * ({@code 1*2 + 1*2 = 4}) and the vector's own {@code dot} implementation.
 */
@Test
public void testDot() {
    DenseLocalOnHeapVector v1 = new DenseLocalOnHeapVector(new double[] { 1.0, 1.0 });
    DenseLocalOnHeapVector v2 = new DenseLocalOnHeapVector(new double[] { 2.0, 2.0 });

    double actual = Blas.dot(v1, v2);

    // JUnit's assertEquals takes (expected, actual, delta); the value under test goes second
    // so that failure messages report "expected" and "actual" correctly.
    // An absolute check guards against both implementations being wrong in the same way.
    Assert.assertEquals(4.0, actual, 0.0);
    Assert.assertEquals(v1.dot(v2), actual, 0.0);
}
Usage of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class DiagonalMatrixTest, method testSetGet.
/**
 * Checks element access on diagonal matrices built through each construction path:
 * the shared test matrix, diagonals taken from non-square dense matrices,
 * and diagonals built from a vector and from a raw array.
 */
@Test
public void testSetGet() {
    verifyDiagonal(testMatrix);

    final int size = MathTestConstants.STORAGE_SIZE;

    // Non-square sources: one with an extra row, one with an extra column.
    final Matrix[] sources = {
        new DenseLocalOnHeapMatrix(size + 1, size),
        new DenseLocalOnHeapMatrix(size, size + 1)
    };

    for (Matrix src : sources) {
        fillMatrix(src);

        verifyDiagonal(new DiagonalMatrix(src));
    }

    // Diagonal values 1, 2, ..., size.
    final double[] data = new double[size];

    for (int idx = 0; idx < size; idx++)
        data[idx] = idx + 1;

    // Construction from a vector.
    final Matrix fromVector = new DiagonalMatrix(new DenseLocalOnHeapVector(data));

    assertEquals("Rows in matrix constructed from vector", size, fromVector.rowSize());
    assertEquals("Cols in matrix constructed from vector", size, fromVector.columnSize());

    for (int idx = 0; idx < size; idx++) {
        assertEquals(UNEXPECTED_VALUE + " at vector index " + idx, data[idx], fromVector.get(idx, idx), 0d);
    }

    verifyDiagonal(fromVector);

    // Construction from a plain array.
    final Matrix fromArray = new DiagonalMatrix(data);

    assertEquals("Rows in matrix constructed from array", size, fromArray.rowSize());
    assertEquals("Cols in matrix constructed from array", size, fromArray.columnSize());

    for (int idx = 0; idx < size; idx++) {
        assertEquals(UNEXPECTED_VALUE + " at array index " + idx, data[idx], fromArray.get(idx, idx), 0d);
    }

    verifyDiagonal(fromArray);
}
Usage of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class MatrixImplementationsTest, method testTimesVector.
/**
 * Tests matrix-by-vector multiplication for every sample matrix.
 *
 * <p>The result of {@code m.times(vector)} is compared against a naive row-by-column product.
 * {@code testInvalidCardinality} is then applied to a call with a vector one element longer
 * than the matrix's column count.
 */
@Test
public void testTimesVector() {
    consumeSampleMatrix((m, desc) -> {
        if (ignore(m.getClass()))
            return;

        if (m instanceof DenseLocalOffHeapMatrix) {
            // TODO: IGNITE-5535, waiting offheap support.
            return;
        }

        double[][] data = fillAndReturn(m);
        double[] arr = fillArray(m.columnSize());

        Vector times = m.times(new DenseLocalOnHeapVector(arr));

        // JUnit's assertEquals takes (message, expected, actual[, delta]); keep the
        // reference value first so failure messages are not inverted.
        assertEquals("Unexpected vector size for " + desc, m.rowSize(), times.size());

        for (int i = 0; i < m.rowSize(); i++) {
            // Reference value: dot product of row i of the matrix data with the input array.
            double exp = 0.0;

            for (int j = 0; j < m.columnSize(); j++)
                exp += arr[j] * data[i][j];

            assertEquals("Unexpected value for " + desc + " at " + i, exp, times.get(i), DEFAULT_DELTA);
        }

        testInvalidCardinality(() -> m.times(new DenseLocalOnHeapVector(m.columnSize() + 1)), desc);
    });
}
Usage of org.apache.ignite.ml.math.impls.vector.DenseLocalOnHeapVector in the Apache Ignite project.
Class BlasTest, method testGemvSparseDenseDense.
/**
 * Tests 'gemv' operation for sparse matrix A, dense vector x and dense vector y.
 *
 * <p>Computes {@code y := alpha * A * x + beta * y}. The expected value is built with plain
 * matrix/vector algebra before the call. The result is then read back from {@code y}, which
 * this test expects {@code Blas.gemv} to update in place.
 */
@Test
public void testGemvSparseDenseDense() {
    // y := alpha * A * x + beta * y
    double alpha = 3.0;

    // Sparse storage populated with the same values the dense-matrix variant uses.
    SparseLocalOnHeapMatrix a = (SparseLocalOnHeapMatrix) new SparseLocalOnHeapMatrix(2, 2)
        .assign(new double[][] { { 10.0, 11.0 }, { 0.0, 1.0 } });
    DenseLocalOnHeapVector x = new DenseLocalOnHeapVector(new double[] { 1.0, 2.0 });
    double beta = 2.0;
    DenseLocalOnHeapVector y = new DenseLocalOnHeapVector(new double[] { 3.0, 4.0 });

    // Reference result, computed before gemv overwrites y.
    DenseLocalOnHeapVector exp = (DenseLocalOnHeapVector) y.times(beta).plus(a.times(x).times(alpha));

    Blas.gemv(alpha, a, x, beta, y);

    Assert.assertEquals(exp, y);
}
Aggregations