Search in sources:

Example 11 with Matrix

Use of org.apache.ignite.ml.math.primitives.matrix.Matrix in the Apache Ignite project.

From class MLPTest, method testDifferentiation.

/**
 * Tests that the gradient returned by {@code MultilayerPerceptron#differentiateByParameters} for a
 * single-neuron sigmoid layer under MSE loss matches the analytically derived gradient.
 */
@Test
public void testDifferentiation() {
    int inputSize = 2;
    int firstLayerNeuronsCnt = 1;
    double w10 = 0.1;
    double w11 = 0.2;
    // The two gradients are produced by different sequences of floating-point operations
    // (backpropagation vs. the closed form below), so they are compared with a tolerance
    // instead of bit-for-bit equality.
    double eps = 1e-10;

    MLPArchitecture conf = new MLPArchitecture(inputSize).withAddedLayer(firstLayerNeuronsCnt, false, Activators.SIGMOID);
    MultilayerPerceptron mlp1 = new MultilayerPerceptron(conf);
    mlp1.setWeight(1, 0, 0, w10);
    MultilayerPerceptron mlp = mlp1.setWeight(1, 1, 0, w11);

    double x0 = 1.0;
    double x1 = 3.0;
    // The single sample is stored as a column (it is read back below via getCol(0)).
    Matrix inputs = new DenseMatrix(new double[][] { { x0, x1 } }).transpose();
    double ytt = 1.0;
    Matrix truth = new DenseMatrix(new double[][] { { ytt } }).transpose();

    Vector grad = mlp.differentiateByParameters(LossFunctions.MSE, inputs, truth);

    // Let yt be the ground-truth value of y.
    // d/dw1i [(yt - sigma(w10 * x0 + w11 * x1))^2] =
    // 2 * (yt - sigma(w10 * x0 + w11 * x1)) * (-1) * (sigma(w10 * x0 + w11 * x1)) * (1 - sigma(w10 * x0 + w11 * x1)) * xi =
    // let z = sigma(w10 * x0 + w11 * x1)
    // - 2 * (yt - z) * z * (1 - z) * xi.
    IgniteTriFunction<Double, Vector, Vector, Vector> partialDer = (yt, w, x) -> {
        double z = Activators.SIGMOID.apply(w.dot(x));
        return x.copy().map(xi -> -2 * (yt - z) * z * (1 - z) * xi);
    };

    Vector weightsVec = mlp.weights(1).getRow(0);
    Tracer.showAscii(weightsVec);

    Vector trueGrad = partialDer.apply(ytt, weightsVec, inputs.getCol(0));
    Tracer.showAscii(trueGrad);
    Tracer.showAscii(grad);

    Assert.assertEquals(mlp.architecture().parametersCount(), grad.size());
    TestUtils.checkIsInEpsilonNeighbourhood(trueGrad, grad, eps);
}
Also used: Matrix(org.apache.ignite.ml.math.primitives.matrix.Matrix) TestUtils(org.apache.ignite.ml.TestUtils) Vector(org.apache.ignite.ml.math.primitives.vector.Vector) IgniteTriFunction(org.apache.ignite.ml.math.functions.IgniteTriFunction) LossFunctions(org.apache.ignite.ml.optimization.LossFunctions) Tracer(org.apache.ignite.ml.math.Tracer) Test(org.junit.Test) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) Assert(org.junit.Assert) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Example 12 with Matrix

Use of org.apache.ignite.ml.math.primitives.matrix.Matrix in the Apache Ignite project.

From class MLPTest, method testSimpleMLPPrediction.

/**
 * Tests that an MLP with 2 layers (a single input and a single sigmoid output neuron) whose weights
 * are all equal to 1 is equivalent to the sigmoid function.
 */
@Test
public void testSimpleMLPPrediction() {
    MLPArchitecture conf = new MLPArchitecture(1).withAddedLayer(1, false, Activators.SIGMOID);
    MultilayerPerceptron mlp = new MultilayerPerceptron(conf, new MLPConstInitializer(1));

    int input = 2;

    Matrix predict = mlp.predict(new DenseMatrix(new double[][] { { input } }));

    // Expected value first: JUnit's assertEquals is (expected, actual), so this order makes
    // failure messages label the two matrices correctly.
    Assert.assertEquals(new DenseMatrix(new double[][] { { Activators.SIGMOID.apply(input) } }), predict);
}
Also used: Matrix(org.apache.ignite.ml.math.primitives.matrix.Matrix) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) Test(org.junit.Test)

Example 13 with Matrix

Use of org.apache.ignite.ml.math.primitives.matrix.Matrix in the Apache Ignite project.

From class MLPTrainerIntegrationTest, method xorTest.

/**
 * Shared body for the 'XOR' tests: loads the four-row XOR truth table into a temporary cache,
 * trains an MLP on it with the given update strategy, and checks the prediction for the
 * {@code (0, 0)} input. The cache is destroyed afterwards even if training or the check fails.
 *
 * @param updatesStgy Update strategy.
 * @param <P> Updater parameters type.
 */
private <P extends Serializable> void xorTest(UpdatesStrategy<? super MultilayerPerceptron, P> updatesStgy) {
    CacheConfiguration<Integer, LabeledVector<double[]>> cacheCfg = new CacheConfiguration<>();
    cacheCfg.setName("XorData");
    cacheCfg.setAffinity(new RendezvousAffinityFunction(false, 5));

    IgniteCache<Integer, LabeledVector<double[]>> dataCache = ignite.createCache(cacheCfg);

    try {
        // XOR truth table: row i of xorInputs is labeled with xorOutputs[i].
        double[][] xorInputs = { { 0.0, 0.0 }, { 0.0, 1.0 }, { 1.0, 0.0 }, { 1.0, 1.0 } };
        double[] xorOutputs = { 0.0, 1.0, 1.0, 0.0 };

        for (int key = 0; key < xorInputs.length; key++) {
            // Clone so the cached vectors never share storage with the prediction input below.
            double[] features = xorInputs[key].clone();
            dataCache.put(key, VectorUtils.of(features).labeled(new double[] { xorOutputs[key] }));
        }

        MLPArchitecture arch = new MLPArchitecture(2)
            .withAddedLayer(10, true, Activators.RELU)
            .withAddedLayer(1, false, Activators.SIGMOID);

        MLPTrainer<P> trainer = new MLPTrainer<>(arch, LossFunctions.MSE, updatesStgy, 2500, 4, 50, 123L);

        MultilayerPerceptron mlp = trainer.fit(ignite, dataCache, new LabeledDummyVectorizer<>());

        DenseMatrix testInputs = new DenseMatrix(xorInputs);
        Matrix predict = mlp.predict(testInputs);
        Tracer.showAscii(predict);

        DenseVector expectedForZeroZero = new DenseVector(new double[] { 0.0 });
        double errNorm = expectedForZeroZero.minus(predict.getRow(0)).kNorm(2);
        X.println(String.valueOf(errNorm));

        // NOTE(review): only the (0, 0) row is asserted; rows for (0, 1), (1, 0) and (1, 1) are
        // printed but never checked, so a model that always outputs ~0 would still pass.
        // TODO consider asserting all four rows once convergence is confirmed for every updater.
        TestUtils.checkIsInEpsilonNeighbourhood(expectedForZeroZero, predict.getRow(0), 1E-1);
    }
    finally {
        dataCache.destroy();
    }
}
Also used: MLPArchitecture(org.apache.ignite.ml.nn.architecture.MLPArchitecture) LabeledVector(org.apache.ignite.ml.structures.LabeledVector) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) Matrix(org.apache.ignite.ml.math.primitives.matrix.Matrix) RendezvousAffinityFunction(org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction) CacheConfiguration(org.apache.ignite.configuration.CacheConfiguration) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector)

Example 14 with Matrix

Use of org.apache.ignite.ml.math.primitives.matrix.Matrix in the Apache Ignite project.

From class VectorToMatrixTest, method testToMatrix.

/**
 * Checks {@code toMatrix} on every testable sample vector: after filling the vector with non-zero
 * values, each element must be found at its own index in both the {@code toMatrix(true)} and the
 * {@code toMatrix(false)} result.
 */
@Test
public void testToMatrix() {
    consumeSampleVectors((vec, description) -> {
        if (availableForTesting(vec)) {
            fillWithNonZeroes(vec);

            final Matrix asRow = vec.toMatrix(true);
            final Matrix asCol = vec.toMatrix(false);

            for (Vector.Element el : vec.all()) {
                assertToMatrixValue(description, asRow, asCol, el.get(), el.index());
            }
        }
    });
}
Also used : Matrix(org.apache.ignite.ml.math.primitives.matrix.Matrix) VectorizedViewMatrix(org.apache.ignite.ml.math.primitives.vector.impl.VectorizedViewMatrix) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) SparseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.SparseMatrix) DelegatingVector(org.apache.ignite.ml.math.primitives.vector.impl.DelegatingVector) SparseVector(org.apache.ignite.ml.math.primitives.vector.impl.SparseVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) Test(org.junit.Test)

Example 15 with Matrix

Use of org.apache.ignite.ml.math.primitives.matrix.Matrix in the Apache Ignite project.

From class VectorToMatrixTest, method testToMatrixPlusOne.

/**
 * Checks {@code toMatrixPlusOne} on every testable sample vector for several leading values:
 * position (0, 0) of both the {@code true} and {@code false} results must be close to the
 * supplied value, and every vector element must appear one index later than in the vector.
 */
@Test
public void testToMatrixPlusOne() {
    consumeSampleVectors((vec, description) -> {
        if (!availableForTesting(vec))
            return;

        fillWithNonZeroes(vec);

        final double[] leadingVals = { -1, 0, 1, 2 };

        for (double zeroVal : leadingVals) {
            final Matrix asRow = vec.toMatrixPlusOne(true, zeroVal);
            final Matrix asCol = vec.toMatrixPlusOne(false, zeroVal);

            final Metric rowHead = new Metric(zeroVal, asRow.get(0, 0));
            assertTrue("Not close enough row like " + rowHead + " at index 0 in " + description, rowHead.closeEnough());

            final Metric colHead = new Metric(zeroVal, asCol.get(0, 0));
            assertTrue("Not close enough cols like " + colHead + " at index 0 in " + description, colHead.closeEnough());

            for (Vector.Element el : vec.all()) {
                // Slot 0 holds zeroVal, so element i of the vector lands at position i + 1.
                int shiftedIdx = el.index() + 1;
                assertToMatrixValue(description, asRow, asCol, el.get(), shiftedIdx);
            }
        }
    });
}
Also used : Matrix(org.apache.ignite.ml.math.primitives.matrix.Matrix) VectorizedViewMatrix(org.apache.ignite.ml.math.primitives.vector.impl.VectorizedViewMatrix) DenseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix) SparseMatrix(org.apache.ignite.ml.math.primitives.matrix.impl.SparseMatrix) DelegatingVector(org.apache.ignite.ml.math.primitives.vector.impl.DelegatingVector) SparseVector(org.apache.ignite.ml.math.primitives.vector.impl.SparseVector) DenseVector(org.apache.ignite.ml.math.primitives.vector.impl.DenseVector) Test(org.junit.Test)

Aggregations

Matrix (org.apache.ignite.ml.math.primitives.matrix.Matrix): 31; DenseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.DenseMatrix): 24; Test (org.junit.Test): 12; Vector (org.apache.ignite.ml.math.primitives.vector.Vector): 8; MLPArchitecture (org.apache.ignite.ml.nn.architecture.MLPArchitecture): 8; DenseVector (org.apache.ignite.ml.math.primitives.vector.impl.DenseVector): 7; SparseMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.SparseMatrix): 6; VectorizedViewMatrix (org.apache.ignite.ml.math.primitives.vector.impl.VectorizedViewMatrix): 5; RendezvousAffinityFunction (org.apache.ignite.cache.affinity.rendezvous.RendezvousAffinityFunction): 2; CacheConfiguration (org.apache.ignite.configuration.CacheConfiguration): 2; ViewMatrix (org.apache.ignite.ml.math.primitives.matrix.impl.ViewMatrix): 2; DelegatingVector (org.apache.ignite.ml.math.primitives.vector.impl.DelegatingVector): 2; SparseVector (org.apache.ignite.ml.math.primitives.vector.impl.SparseVector): 2; MLPTrainer (org.apache.ignite.ml.nn.MLPTrainer): 2; SimpleGDParameterUpdate (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDParameterUpdate): 2; SimpleGDUpdateCalculator (org.apache.ignite.ml.optimization.updatecalculators.SimpleGDUpdateCalculator): 2; LabeledVector (org.apache.ignite.ml.structures.LabeledVector): 2; ArrayList (java.util.ArrayList): 1; LinkedList (java.util.LinkedList): 1; Random (java.util.Random): 1