
Example 1 with AdaGrad

Use of org.nd4j.linalg.learning.config.AdaGrad in project nd4j by deeplearning4j.

Source: class UpdaterTest, method testAdaGrad.
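For reference, this is the textbook AdaGrad rule that the AdaGrad(learningRate, epsilon) config parameterizes. Each parameter accumulates the sum of its squared gradients, and its step is the learning rate η (0.1 in this test) divided by the square root of that sum plus a small ε. Here g_t is the gradient at step t, G_t is the running sum, and ⊙ denotes element-wise multiplication. Implementations differ on whether ε sits inside or outside the square root, so treat this as the general form rather than nd4j's exact code.

$$
\begin{aligned}
G_t &= G_{t-1} + g_t \odot g_t \\
\theta_{t+1} &= \theta_t - \frac{\eta}{\sqrt{G_t} + \epsilon} \odot g_t
\end{aligned}
$$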

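@Test
public void testAdaGrad() {
    int rows = 10;
    int cols = 2;
    // AdaGrad config: learning rate 0.1 and the library's default epsilon
    AdaGradUpdater grad = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
    // The updater keeps its state in a flat 1 x (rows * cols) view for a rows x cols gradient
    // laid out in 'c' (row-major) order; the final flag requests initialization of that state
    grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new int[] { rows, cols }, 'c', true);
    // Fill a rows x cols matrix, row by row, with samples from a normal distribution (mean 1, std 1)
    INDArray W = Nd4j.zeros(rows, cols);
    Distribution dist = Nd4j.getDistributions().createNormal(1, 1);
    for (int i = 0; i < W.rows(); i++) {
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));
    }
    for (int i = 0; i < 5; i++) {
        // The updater call is commented out in the source, so as written this loop
        // only perturbs W with standard Gaussian noise
        // String learningRates = String.valueOf("\nAdagrad\n " + grad.applyUpdater(W, i)).replaceAll(";", "\n");
        // System.out.println(learningRates);
        W.addi(Nd4j.randn(rows, cols));
    }
}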
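The test above configures an AdaGradUpdater but never applies it. The following is a minimal, dependency-free Java sketch of the element-wise step such an updater performs, mirroring the test's setup: a 10 x 2 parameter block stored flat in row-major order, initial values drawn from N(1, 1), and five iterations. It is not nd4j's implementation. The class AdaGradSketch and method step are hypothetical, the random gradients stand in for real ones, ε is assumed to be added outside the square root, and 1e-6 is only a stand-in for AdaGrad.DEFAULT_ADAGRAD_EPSILON.

```java
import java.util.Arrays;
import java.util.Random;

/**
 * Hypothetical, dependency-free sketch of element-wise AdaGrad.
 * Assumption: epsilon is added outside the square root. This is not nd4j's code.
 */
public class AdaGradSketch {

    /**
     * Applies one AdaGrad step in place.
     * history accumulates squared gradients; each parameter moves by
     * -lr * grad / (sqrt(history) + eps).
     */
    public static void step(double[] params, double[] grad, double[] history,
                            double lr, double eps) {
        for (int i = 0; i < params.length; i++) {
            history[i] += grad[i] * grad[i];
            params[i] -= lr * grad[i] / (Math.sqrt(history[i]) + eps);
        }
    }

    public static void main(String[] args) {
        int rows = 10;
        int cols = 2;
        // Flat row-major ('c' order) storage, analogous to the 1 x (rows * cols) state view in the test
        double[] params = new double[rows * cols];
        double[] history = new double[rows * cols];
        Random rng = new Random(42);

        // Initial parameters ~ N(1, 1), like the test's normal distribution
        for (int i = 0; i < params.length; i++) {
            params[i] = 1.0 + rng.nextGaussian();
        }

        for (int iter = 0; iter < 5; iter++) {
            // Hypothetical gradient: standard Gaussian noise in place of a real loss gradient
            double[] grad = new double[params.length];
            for (int i = 0; i < grad.length; i++) {
                grad[i] = rng.nextGaussian();
            }
            step(params, grad, history, 0.1, 1e-6); // 1e-6 stands in for the library default
        }

        System.out.println(Arrays.toString(params));
    }
}
```

On the first step the history equals g², so with a tiny ε each parameter moves by roughly the learning rate, opposite the sign of its gradient; later steps shrink as the accumulated history grows.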
Also used: INDArray (org.nd4j.linalg.api.ndarray.INDArray), Distribution (org.nd4j.linalg.api.rng.distribution.Distribution), AdaGrad (org.nd4j.linalg.learning.config.AdaGrad), Test (org.junit.Test), BaseNd4jTest (org.nd4j.linalg.BaseNd4jTest)
