Example 16 with Distribution

Use of org.nd4j.linalg.api.rng.distribution.Distribution in project nd4j by deeplearning4j.

From the class UpdaterTest, method testAdaGrad.

@Test
public void testAdaGrad() {
    int rows = 10;
    int cols = 2;
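    // AdaGrad with learning rate 0.1 and the default epsilon
    AdaGradUpdater grad = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
    // AdaGrad keeps one state value per parameter (the running sum of squared gradients),
    // so the state view needs rows * cols elements; 'c' is the gradient order, true initializes the view
    grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new int[] { rows, cols }, 'c', true);
    // Fill W row by row with samples from a normal distribution (mean 1, standard deviation 1)
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1, 1);
    for (int i = 0; i < W.rows(); i++) {
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));
    }
    for (int i = 0; i < 5; i++) {
        // String learningRates = String.valueOf("\nAdagrad\n " + grad.applyUpdater(W, i)).replaceAll(";", "\n");
        // System.out.println(learningRates);
        // Perturb W with standard normal noise on each iteration
        W.addi(Nd4j.randn(rows, cols));
    }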
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), Distribution (org.nd4j.linalg.api.rng.distribution.Distribution), AdaGrad (org.nd4j.linalg.learning.config.AdaGrad), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)
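testAdaGrad sizes the updater state but never prints or checks an update, because the debugging lines are commented out. To show what that per-parameter state holds, here is a minimal plain-Java sketch of the textbook AdaGrad rule on a double[]: accumulate squared gradients, then scale each gradient by the learning rate divided by the square root of its accumulated history. It does not use nd4j; the class AdaGradSketch and its update method are hypothetical names, and nd4j's own AdaGradUpdater may differ in details such as exactly where epsilon is added.

```java
import java.util.Arrays;

// Hypothetical illustration of the standard AdaGrad rule, not nd4j's AdaGradUpdater
public class AdaGradSketch {
    private final double learningRate;
    private final double epsilon;
    private final double[] history; // running sum of squared gradients, one entry per parameter

    public AdaGradSketch(double learningRate, double epsilon, int numParams) {
        this.learningRate = learningRate;
        this.epsilon = epsilon;
        this.history = new double[numParams];
    }

    // Overwrites the gradient in place with the AdaGrad-scaled update
    public void update(double[] gradient) {
        for (int j = 0; j < gradient.length; j++) {
            history[j] += gradient[j] * gradient[j];
            gradient[j] = learningRate * gradient[j] / (Math.sqrt(history[j]) + epsilon);
        }
    }

    public static void main(String[] args) {
        // Same learning rate as the test; epsilon chosen here only for illustration
        AdaGradSketch updater = new AdaGradSketch(0.1, 1e-6, 2);
        double[] g = {1.0, -2.0};
        updater.update(g);
        System.out.println(Arrays.toString(g));
    }
}
```

On the first step each history entry equals the squared gradient, so every component of the update has magnitude close to the learning rate (about 0.1 here) regardless of the raw gradient size. Later steps shrink as the history grows.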

Example 17 with Distribution

Use of org.nd4j.linalg.api.rng.distribution.Distribution in project nd4j by deeplearning4j.

From the class UpdaterTest, method testAdaMax.

@Test
public void testAdaMax() {
    int rows = 10;
    int cols = 2;
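    // AdaMax with its default hyperparameters
    AdaMaxUpdater grad = new AdaMaxUpdater(new AdaMax());
    // AdaMax keeps two state values per parameter (first moment m and infinity norm u),
    // so the state view needs 2 * rows * cols elements
    grad.setStateViewArray(Nd4j.zeros(1, 2 * rows * cols), new int[] { rows, cols }, 'c', true);
    // Fill W row by row with samples from a normal distribution (mean 1e-3, standard deviation 1e-3)
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int i = 0; i < W.rows(); i++) {
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));
    }
    for (int i = 0; i < 5; i++) {
        // String learningRates = String.valueOf("\nAdaMax\n " + grad.getGradient(W, i)).replaceAll(";", "\n");
        // System.out.println(learningRates);
        // Perturb W with standard normal noise on each iteration
        W.addi(Nd4j.randn(rows, cols));
    }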
}
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), Distribution (org.nd4j.linalg.api.rng.distribution.Distribution), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)
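testAdaMax allocates a state view twice the size of the AdaGrad one, which matches AdaMax tracking two quantities per parameter: an exponentially weighted first moment m and an exponentially weighted infinity norm u. The following plain-Java sketch of the textbook AdaMax rule (with bias correction of the first moment) makes those two arrays explicit. It is a hypothetical illustration that does not use nd4j; the names AdaMaxSketch and update are made up, the hyperparameters in main are example values, and nd4j's AdaMaxUpdater may place epsilon or handle the step counter differently.

```java
import java.util.Arrays;

// Hypothetical illustration of the standard AdaMax rule, not nd4j's AdaMaxUpdater
public class AdaMaxSketch {
    private final double learningRate, beta1, beta2, epsilon;
    private final double[] m; // exponentially weighted first moment, one entry per parameter
    private final double[] u; // exponentially weighted infinity norm, one entry per parameter
    private int t = 0;        // step counter used for bias correction

    public AdaMaxSketch(double learningRate, double beta1, double beta2, double epsilon, int numParams) {
        this.learningRate = learningRate;
        this.beta1 = beta1;
        this.beta2 = beta2;
        this.epsilon = epsilon;
        this.m = new double[numParams];
        this.u = new double[numParams];
    }

    // Overwrites the gradient in place with the AdaMax update
    public void update(double[] gradient) {
        t++;
        // Bias-corrected step size: m starts at zero, so early estimates are scaled up
        double alphaT = learningRate / (1.0 - Math.pow(beta1, t));
        for (int j = 0; j < gradient.length; j++) {
            m[j] = beta1 * m[j] + (1.0 - beta1) * gradient[j];
            u[j] = Math.max(beta2 * u[j], Math.abs(gradient[j]));
            gradient[j] = alphaT * m[j] / (u[j] + epsilon); // epsilon guards against u == 0
        }
    }

    public static void main(String[] args) {
        AdaMaxSketch updater = new AdaMaxSketch(0.002, 0.9, 0.999, 1e-8, 2);
        double[] g = {1.0, -2.0};
        updater.update(g);
        System.out.println(Arrays.toString(g));
    }
}
```

On the first step m equals (1 - beta1) times the gradient and u equals its absolute value, so after bias correction each component of the update has magnitude very close to the learning rate (about 0.002 here), with the sign of the gradient.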

Aggregations

Distribution (org.nd4j.linalg.api.rng.distribution.Distribution): 17
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 16
Test (org.junit.Test): 9
BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest): 7
Pair (org.deeplearning4j.berkeley.Pair): 2
LinkedHashMap (java.util.LinkedHashMap): 1
NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution): 1
VariationalAutoencoder (org.deeplearning4j.nn.conf.layers.variational.VariationalAutoencoder): 1
Gradient (org.deeplearning4j.nn.gradient.Gradient): 1
WeightInit (org.deeplearning4j.nn.weights.WeightInit): 1
ActivationSigmoid (org.nd4j.linalg.activations.impl.ActivationSigmoid): 1
MatchCondition (org.nd4j.linalg.api.ops.impl.accum.MatchCondition): 1
DefaultRandom (org.nd4j.linalg.api.rng.DefaultRandom): 1
Random (org.nd4j.linalg.api.rng.Random): 1
NormalDistribution (org.nd4j.linalg.api.rng.distribution.impl.NormalDistribution): 1
OrthogonalDistribution (org.nd4j.linalg.api.rng.distribution.impl.OrthogonalDistribution): 1
AdaGrad (org.nd4j.linalg.learning.config.AdaGrad): 1
NativeRandom (org.nd4j.rng.NativeRandom): 1