Uses of org.apache.ignite.ml.math.primitives.vector.impl.DenseVector in the Apache Ignite project.
Class BlasTest, method testGemvSparseSparseDense:
/**
 * Tests 'gemv' operation for sparse matrix A, sparse vector x and dense vector y.
 */
@Test
public void testGemvSparseSparseDense() {
    // y := alpha * A * x + beta * y
    double alpha = 3.0;
    DenseMatrix a = new DenseMatrix(new double[][] {{10.0, 11.0}, {0.0, 1.0}}, 2);
    SparseVector x = sparseFromArray(new double[] {1.0, 2.0});
    double beta = 2.0;
    DenseVector y = new DenseVector(new double[] {3.0, 4.0});

    DenseVector exp = (DenseVector)y.times(beta).plus(a.times(x).times(alpha));

    Blas.gemv(alpha, a, x, beta, y);

    Assert.assertEquals(exp, y);
}
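For these values the expectation works out by hand: A·x = (10·1 + 11·2, 0·1 + 1·2) = (32, 2), so after the call y = (3·32 + 2·3, 3·2 + 2·4) = (102, 14), which is exactly what exp holds.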
Class BlasTest, method testSyrNonSquareMatrix:
/**
 * Tests 'syr' operation for non-square dense matrix A.
 */
@Test(expected = NonSquareMatrixException.class)
public void testSyrNonSquareMatrix() {
    double alpha = 3.0;
    DenseMatrix a = new DenseMatrix(new double[][] {{10.0, 11.0, 12.0}, {0.0, 1.0, 2.0}}, 2);
    Vector x = new DenseVector(new double[] {1.0, 2.0});

    new Blas().syr(alpha, x, a);
}
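On the happy path, 'syr' performs the symmetric rank-1 update A := alpha * x * xᵀ + A, which is only well-formed when A is n × n for an x of length n; that is what the expected NonSquareMatrixException guards here. A minimal plain-Java sketch of the update, with no Ignite types (syrReference is a hypothetical name, not part of the Blas API):

static void syrReference(double alpha, double[] x, double[][] a) {
    int n = x.length;

    if (a.length != n || a[0].length != n)
        throw new IllegalArgumentException("A must be " + n + "x" + n);

    for (int i = 0; i < n; i++)
        for (int j = 0; j < n; j++)
            a[i][j] += alpha * x[i] * x[j]; // rank-1 term alpha * x[i] * x[j]
}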
Class BlasTest, method testGemvDenseDenseDense:
/**
 * Tests 'gemv' operation for dense matrix A, dense vector x and dense vector y.
 */
@Test
public void testGemvDenseDenseDense() {
    // y := alpha * A * x + beta * y
    double alpha = 3.0;
    DenseMatrix a = new DenseMatrix(new double[][] {{10.0, 11.0}, {0.0, 1.0}}, 2);
    DenseVector x = new DenseVector(new double[] {1.0, 2.0});
    double beta = 2.0;
    DenseVector y = new DenseVector(new double[] {3.0, 4.0});

    DenseVector exp = (DenseVector)y.times(beta).plus(a.times(x).times(alpha));

    Blas.gemv(alpha, a, x, beta, y);

    Assert.assertEquals(exp, y);
}
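Both gemv tests check the same contract. As a reference for what the call computes, here is a minimal plain-Java sketch of the update with no Ignite types (gemvReference is a hypothetical name, not part of the Blas API). Note the difference in style inside the tests themselves: the times/plus chain allocates fresh vectors for exp, while Blas.gemv overwrites y in place.

static double[] gemvReference(double alpha, double[][] a, double[] x, double beta, double[] y) {
    // y := alpha * A * x + beta * y, computed row by row, in place like Blas.gemv.
    for (int i = 0; i < y.length; i++) {
        double ax = 0.0;

        for (int j = 0; j < x.length; j++)
            ax += a[i][j] * x[j]; // (A * x)[i]

        y[i] = alpha * ax + beta * y[i];
    }

    return y;
}

// For the values used in both tests it returns {102.0, 14.0}.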
Class LinearRegressionSGDTrainer, method updateModel:
/**
 * {@inheritDoc}
 */
@Override
protected <K, V> LinearRegressionModel updateModel(LinearRegressionModel mdl,
    DatasetBuilder<K, V> datasetBuilder, Preprocessor<K, V> extractor) {
    assert updatesStgy != null;

    // Derive the network architecture from the dataset: count feature columns
    // on each partition, then merge the per-partition answers.
    IgniteFunction<Dataset<EmptyContext, SimpleLabeledDatasetData>, MLPArchitecture> archSupplier = dataset -> {
        int cols = dataset.compute(data -> {
            if (data.getFeatures() == null)
                return null;
            return data.getFeatures().length / data.getRows();
        }, (a, b) -> {
            // Reduce partial results, skipping empty partitions (null).
            if (a == null)
                return b == null ? 0 : b;
            if (b == null)
                return a;
            return b;
        });

        MLPArchitecture architecture = new MLPArchitecture(cols);
        architecture = architecture.withAddedLayer(1, true, Activators.LINEAR);

        return architecture;
    };

    MLPTrainer<?> trainer = new MLPTrainer<>(archSupplier, LossFunctions.MSE,
        updatesStgy, maxIterations, batchSize, locIterations, seed);

    // Adapt Double labels to the double[] labels the MLP trainer expects.
    IgniteFunction<LabeledVector<Double>, LabeledVector<double[]>> func =
        lv -> new LabeledVector<>(lv.features(), new double[] {lv.label()});

    PatchedPreprocessor<K, V, Double, double[]> patchedPreprocessor =
        new PatchedPreprocessor<>(func, extractor);

    // Continue training from the previous model state if one exists,
    // otherwise fit from scratch.
    MultilayerPerceptron mlp = Optional.ofNullable(mdl)
        .map(this::restoreMLPState)
        .map(m -> trainer.update(m, datasetBuilder, patchedPreprocessor))
        .orElseGet(() -> trainer.fit(datasetBuilder, patchedPreprocessor));

    double[] p = mlp.parameters().getStorage().data();

    // Parameters of the single linear neuron: weights first, intercept last.
    return new LinearRegressionModel(new DenseVector(Arrays.copyOf(p, p.length - 1)), p[p.length - 1]);
}
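The trained perceptron has a single linear output neuron, so its flat parameter array is laid out as [w_0, ..., w_{n-1}, bias]; the last two lines above unpack exactly that. A hedged sketch of how such a parameter array scores a feature vector (linearScore is a hypothetical helper, not the LinearRegressionModel API):

static double linearScore(double[] p, double[] features) {
    double s = p[p.length - 1]; // bias / intercept is the last parameter

    for (int i = 0; i < features.length; i++)
        s += p[i] * features[i]; // weight w_i times feature x_i

    return s;
}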
Class Deltas, method getStateVector:
/**
 * @param mdl Model.
 * @return Vector of model weights with the intercept.
 */
private Vector getStateVector(SVMLinearClassificationModel mdl) {
    double intercept = mdl.intercept();
    Vector weights = mdl.weights();

    int stateVectorSize = weights.size() + 1;
    Vector res = weights.isDense()
        ? new DenseVector(stateVectorSize)
        : new SparseVector(stateVectorSize);

    // Intercept at index 0, weights shifted to indices 1..n so the first
    // weight does not overwrite the intercept.
    res.set(0, intercept);
    weights.nonZeroes().forEach(ith -> res.set(ith.index() + 1, ith.get()));

    return res;
}
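The resulting state vector therefore stores the intercept at index 0 and the i-th weight at index i + 1. A hedged sketch of the inverse mapping under that layout (interceptOf and weightsOf are hypothetical helpers, not Ignite API):

static double interceptOf(Vector state) {
    return state.get(0); // intercept lives at index 0
}

static double[] weightsOf(Vector state) {
    double[] w = new double[state.size() - 1];

    for (int i = 0; i < w.length; i++)
        w[i] = state.get(i + 1); // weights follow, shifted by one

    return w;
}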