Example 6 with DenseMatrix

use of org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix in project ignite by apache.

the class StackingTest method testSimpleStack.

/**
 * Tests simple stack training.
 */
@Test
public void testSimpleStack() {
    StackedDatasetTrainer<Vector, Vector, Double, LinearRegressionModel, Double> trainer = new StackedDatasetTrainer<>();
    UpdatesStrategy<SmoothParametrized, SimpleGDParameterUpdate> updatesStgy = new UpdatesStrategy<>(new SimpleGDUpdateCalculator(0.2), SimpleGDParameterUpdate.SUM_LOCAL, SimpleGDParameterUpdate.AVG);
    MLPArchitecture arch = new MLPArchitecture(2).withAddedLayer(10, true, Activators.RELU).withAddedLayer(1, false, Activators.SIGMOID);
    MLPTrainer<SimpleGDParameterUpdate> trainer1 = new MLPTrainer<>(arch, LossFunctions.MSE, updatesStgy, 3000, 10, 50, 123L);
    // Convert model trainer to produce Vector -> Vector model
    DatasetTrainer<AdaptableDatasetModel<Vector, Vector, Matrix, Matrix, MultilayerPerceptron>, Double> mlpTrainer =
        AdaptableDatasetTrainer.of(trainer1)
            .beforeTrainedModel((Vector v) -> new DenseMatrix(v.asArray(), 1))
            .afterTrainedModel((Matrix mtx) -> mtx.getRow(0))
            .withConvertedLabels(VectorUtils::num2Arr);
    final double factor = 3;
    StackedModel<Vector, Vector, Double, LinearRegressionModel> mdl = trainer
        .withAggregatorTrainer(new LinearRegressionLSQRTrainer().withConvertedLabels(x -> x * factor))
        .addTrainer(mlpTrainer)
        .withAggregatorInputMerger(VectorUtils::concat)
        .withSubmodelOutput2VectorConverter(IgniteFunction.identity())
        .withVector2SubmodelInputConverter(IgniteFunction.identity())
        .withOriginalFeaturesKept(IgniteFunction.identity())
        .withEnvironmentBuilder(TestUtils.testEnvBuilder())
        .fit(getCacheMock(xor), parts, new DoubleArrayVectorizer<Integer>().labeled(Vectorizer.LabelCoordinate.LAST));
    assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(0.0, 0.0)), 0.3);
    assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(0.0, 1.0)), 0.3);
    assertEquals(1.0 * factor, mdl.predict(VectorUtils.of(1.0, 0.0)), 0.3);
    assertEquals(0.0 * factor, mdl.predict(VectorUtils.of(1.0, 1.0)), 0.3);
}
Also used : VectorUtils(org.apache.ignite.ml.math.primitives.vector.VectorUtils) DoubleArrayVectorizer(org.apache.ignite.ml.dataset.feature.extractor.impl.DoubleArrayVectorizer) LinearRegressionModel(org.apache.ignite.ml.regressions.linear.LinearRegressionModel) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) MLPTrainer(org.apache.ignite.ml.nn.MLPTrainer) SimpleGDParameterUpdate(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) LinearRegressionLSQRTrainer(org.apache.ignite.ml.regressions.linear.LinearRegressionLSQRTrainer) Matrix(org.apache.ignite.ml.math.primitives.matrix.Matrix) SimpleGDUpdateCalculator(org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator) AdaptableDatasetModel(org.apache.ignite.ml.trainers.AdaptableDatasetModel) StackedDatasetTrainer(org.apache.ignite.ml.composition.stacking.StackedDatasetTrainer) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) SmoothParametrized(org.apache.ignite.ml.optimization.SmoothParametrized) UpdatesStrategy(org.apache.ignite.ml.nn.UpdatesStrategy) TrainerTest(org.apache.ignite.ml.common.TrainerTest) Test(org.junit.Test)

Example 7 with DenseMatrix

use of org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix in project ignite by apache.

the class BlasTest method testSyrDenseDense.

/**
 * Tests 'syr' operation for dense vector x and dense matrix A.
 */
@Test
public void testSyrDenseDense() {
    double alpha = 2.0;
    DenseVector x = new DenseVector(new double[] { 1.0, 2.0 });
    DenseMatrix a = new DenseMatrix(new double[][] { { 10.0, 20.0 }, { 20.0, 10.0 } });
    // alpha * x * x^T + A
    DenseMatrix exp = (DenseMatrix) new DenseMatrix(new double[][] { { 1.0, 2.0 }, { 2.0, 4.0 } }).times(alpha).plus(a);
    Blas.syr(alpha, x, a);
    Assert.assertEquals(exp, a);
}
Also used : DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) Test(org.junit.Test)
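
The update that the comment in Example 7 describes, A := alpha * x * x^T + A, can be sketched on plain arrays to make the expected values concrete. This is only an illustration under the assumption of a row-major double[][] layout; SyrSketch and its syr method are hypothetical names and are not part of Apache Ignite.

```java
import java.util.Arrays;

// Plain-array sketch of the symmetric rank-1 update A := alpha * x * x^T + A.
// SyrSketch is a hypothetical helper, not the Ignite Blas class.
final class SyrSketch {
    /** Adds alpha * x * x^T to the square matrix a, in place. */
    static void syr(double alpha, double[] x, double[][] a) {
        for (int i = 0; i < x.length; i++)
            for (int j = 0; j < x.length; j++)
                a[i][j] += alpha * x[i] * x[j];
    }

    public static void main(String[] args) {
        // Same inputs as the test above: alpha = 2, x = (1, 2).
        double[][] a = { { 10.0, 20.0 }, { 20.0, 10.0 } };
        syr(2.0, new double[] { 1.0, 2.0 }, a);
        System.out.println(Arrays.deepToString(a)); // prints [[12.0, 24.0], [24.0, 18.0]]
    }
}
```

With the inputs from the test, x * x^T is {{1, 2}, {2, 4}}, so doubling it and adding A gives {{12, 24}, {24, 18}}.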

Example 8 with DenseMatrix

use of org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix in project ignite by apache.

the class BlasTest method testGemmDenseDenseDense.

/**
 * Tests 'gemm' operation for dense matrix A, dense matrix B and dense matrix C.
 */
@Test
public void testGemmDenseDenseDense() {
    // C := alpha * A * B + beta * C
    double alpha = 1.0;
    DenseMatrix a = new DenseMatrix(new double[][] { { 10.0, 11.0 }, { 0.0, 1.0 } });
    DenseMatrix b = new DenseMatrix(new double[][] { { 1.0, 0.3 }, { 0.0, 1.0 } });
    double beta = 0.0;
    DenseMatrix c = new DenseMatrix(new double[][] { { 1.0, 2.0 }, { 2.0, 3.0 } });
    // With alpha = 1.0 and beta = 0.0 the expected result reduces to A * B.
    DenseMatrix exp = (DenseMatrix) a.times(b);
    Blas.gemm(alpha, a, b, beta, c);
    Assert.assertEquals(exp, c);
}
Also used : DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) Test(org.junit.Test)
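
The formula in Example 8, C := alpha * A * B + beta * C, can be written out on plain arrays as a rough sketch. It assumes row-major double[][] matrices with compatible shapes; GemmSketch and its gemm method are hypothetical names for illustration and do not reproduce the Ignite implementation.

```java
// Plain-array sketch of the general matrix product update
// C := alpha * A * B + beta * C.
// GemmSketch is a hypothetical helper, not the Ignite Blas class.
final class GemmSketch {
    /** Overwrites c with alpha * a * b + beta * c. */
    static void gemm(double alpha, double[][] a, double[][] b, double beta, double[][] c) {
        int rows = a.length;
        int inner = b.length;
        int cols = b[0].length;
        for (int i = 0; i < rows; i++) {
            for (int j = 0; j < cols; j++) {
                double sum = 0.0;
                for (int p = 0; p < inner; p++)
                    sum += a[i][p] * b[p][j];
                // beta scales the previous content of C before it is combined.
                c[i][j] = alpha * sum + beta * c[i][j];
            }
        }
    }
}
```

For the test inputs (alpha = 1, beta = 0) the previous content of C is discarded and the result is A * B, which is approximately {{10, 14}, {0, 1}}.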

Example 9 with DenseMatrix

use of org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix in project ignite by apache.

the class BlasTest method testSprDenseDense.

/**
 * Tests 'spr' operation for dense vector v and dense matrix A.
 */
@Test
public void testSprDenseDense() {
    double alpha = 3.0;
    DenseVector v = new DenseVector(new double[] { 1.0, 2.0 });
    DenseVector u = new DenseVector(new double[] { 3.0, 13.0, 20.0, 0.0 });
    // m is alpha * v * v^t
    DenseMatrix m = (DenseMatrix) new DenseMatrix(new double[][] { { 1.0, 0.0 }, { 2.0, 4.0 } }, StorageConstants.COLUMN_STORAGE_MODE).times(alpha);
    DenseMatrix a = new DenseMatrix(new double[][] { { 3.0, 0.0 }, { 13.0, 20.0 } }, StorageConstants.COLUMN_STORAGE_MODE);
    // u := alpha * v * v^T + u, where u holds A in packed form
    Blas.spr(alpha, v, u);
    DenseMatrix mu = fromVector(u, a.rowSize(), StorageConstants.COLUMN_STORAGE_MODE, (i, j) -> i >= j);
    Assert.assertEquals(m.plus(a), mu);
}
Also used : DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) Test(org.junit.Test)
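
The packed update used in Example 9 can be illustrated on plain arrays. The sketch below assumes the conventional BLAS packing, in which the upper triangle of a symmetric matrix is stored column by column in a flat array; whether Ignite uses exactly this layout is not confirmed by the snippet above. SprSketch and its spr method are hypothetical names.

```java
// Plain-array sketch of the packed symmetric rank-1 update
// U := alpha * v * v^T + U, where u stores only the upper triangle
// of a symmetric matrix, column by column (an assumed layout).
// SprSketch is a hypothetical helper, not the Ignite Blas class.
final class SprSketch {
    /** Updates the first n * (n + 1) / 2 entries of u in place, n = v.length. */
    static void spr(double alpha, double[] v, double[] u) {
        int idx = 0;
        for (int j = 0; j < v.length; j++)
            for (int i = 0; i <= j; i++)
                u[idx++] += alpha * v[i] * v[j];
    }
}
```

Under that assumption, alpha = 3, v = (1, 2) and a packed triangle of (3, 13, 20) give (3 + 3, 13 + 6, 20 + 12) = (6, 19, 32).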

Example 10 with DenseMatrix

use of org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix in project ignite by apache.

the class BlasTest method testGemvSparseSparseDense.

/**
 * Tests 'gemv' operation for matrix A (a DenseMatrix in this test), sparse vector x and dense vector y.
 */
@Test
public void testGemvSparseSparseDense() {
    // y := alpha * A * x + beta * y
    double alpha = 3.0;
    DenseMatrix a = new DenseMatrix(new double[][] { { 10.0, 11.0 }, { 0.0, 1.0 } }, 2);
    SparseVector x = sparseFromArray(new double[] { 1.0, 2.0 });
    double beta = 2.0;
    DenseVector y = new DenseVector(new double[] { 3.0, 4.0 });
    DenseVector exp = (DenseVector) y.times(beta).plus(a.times(x).times(alpha));
    Blas.gemv(alpha, a, x, beta, y);
    Assert.assertEquals(exp, y);
}
Also used : SparseVector(org.apache.ignite.ml.math.primitives.vector.impl.SparseVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) Test(org.junit.Test)
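
The formula in Example 10, y := alpha * A * x + beta * y, can be sketched on plain arrays to show where the expected vector comes from. The sketch assumes A is the row-major matrix {{10, 11}, {0, 1}} and ignores the sparse/dense storage distinction; GemvSketch and its gemv method are hypothetical names, not Ignite API.

```java
// Plain-array sketch of the matrix-vector update
// y := alpha * A * x + beta * y.
// GemvSketch is a hypothetical helper, not the Ignite Blas class.
final class GemvSketch {
    /** Overwrites y with alpha * a * x + beta * y. */
    static void gemv(double alpha, double[][] a, double[] x, double beta, double[] y) {
        for (int i = 0; i < a.length; i++) {
            double dot = 0.0;
            for (int j = 0; j < x.length; j++)
                dot += a[i][j] * x[j];
            y[i] = alpha * dot + beta * y[i];
        }
    }
}
```

With alpha = 3, x = (1, 2), beta = 2 and y = (3, 4), the product A * x is (32, 2), so the result is (3 * 32 + 2 * 3, 3 * 2 + 2 * 4) = (102, 14).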

Aggregations

DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix): 39
Test (org.junit.Test): 28
Matrix (org.apache.ignite.ml.math.primitives.matrix.Matrix): 14
DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector): 14
MLPArchitecture (org.apache.ignite.ml.nn.architecture.MLPArchitecture): 9
Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 8
SparseVector (org.apache.ignite.ml.math.primitives.vector.impl.SparseVector): 4
MultivariateGaussianDistribution (org.apache.ignite.ml.math.stat.MultivariateGaussianDistribution): 3
LabeledVector (org.apache.ignite.ml.structures.LabeledVector): 3
RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction): 2
CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration): 2
EuclideanDistance (org.apache.ignite.ml.math.distances.EuclideanDistance): 2
HammingDistance (org.apache.ignite.ml.math.distances.HammingDistance): 2
ManhattanDistance (org.apache.ignite.ml.math.distances.ManhattanDistance): 2
ViewMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.ViewMatrix): 2
VectorizedViewMatrix (org.apache.ignite.ml.math.primitives.vector.impl.VectorizedViewMatrix): 2
MLPTrainer (org.apache.ignite.ml.nn.MLPTrainer): 2
SimpleGDParameterUpdate (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate): 2
SimpleGDUpdateCalculator (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator): 2
Path (java.nio.file.Path): 1