Use of org.nd4j.linalg.api.rng.distribution.Distribution in project nd4j by deeplearning4j.
From the class UpdaterTest, method testAdaGrad.
@Test
public void testAdaGrad() {
    int rows = 10;
    int cols = 2;
    AdaGradUpdater grad = new AdaGradUpdater(new AdaGrad(0.1, AdaGrad.DEFAULT_ADAGRAD_EPSILON));
    // AdaGrad keeps one accumulated-gradient value per parameter, so the state view holds rows * cols entries
    grad.setStateViewArray(Nd4j.zeros(1, rows * cols), new int[] {rows, cols}, 'c', true);
    INDArray W = Nd4j.zeros(rows, cols);
    // Initialize W row by row with samples drawn from a normal distribution with mean 1 and stddev 1
    Distribution dist = Nd4j.getDistributions().createNormal(1, 1);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));
    for (int i = 0; i < 5; i++) {
        // String learningRates = String.valueOf("\nAdagrad\n " + grad.applyUpdater(W, i)).replaceAll(";", "\n");
        // System.out.println(learningRates);
        W.addi(Nd4j.randn(rows, cols));
    }
}
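The per-row loop above samples a double[] for each row and wraps it with Nd4j.create. A more compact alternative is sketched below; it assumes this nd4j version's Distribution interface exposes an INDArray-returning sample(int[] shape) overload, which should be checked against the release in use.

    // Sketch: fill the whole matrix in a single call instead of looping over rows.
    // Assumes Distribution.sample(int[] shape) is available in this nd4j version.
    Distribution dist = Nd4j.getDistributions().createNormal(1, 1);
    INDArray W = dist.sample(new int[] {rows, cols});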
Use of org.nd4j.linalg.api.rng.distribution.Distribution in project nd4j by deeplearning4j.
From the class UpdaterTest, method testAdaMax.
@Test
public void testAdaMax() {
    int rows = 10;
    int cols = 2;
    AdaMaxUpdater grad = new AdaMaxUpdater(new AdaMax());
    // AdaMax tracks two moment estimates per parameter, so the state view holds 2 * rows * cols entries
    grad.setStateViewArray(Nd4j.zeros(1, 2 * rows * cols), new int[] {rows, cols}, 'c', true);
    INDArray W = Nd4j.zeros(rows, cols);
    // Initialize W row by row with samples drawn from a normal distribution with mean 1e-3 and stddev 1e-3
    Distribution dist = Nd4j.getDistributions().createNormal(1e-3, 1e-3);
    for (int i = 0; i < W.rows(); i++)
        W.putRow(i, Nd4j.create(dist.sample(W.columns())));
    for (int i = 0; i < 5; i++) {
        // String learningRates = String.valueOf("\nAdaMax\n " + grad.getGradient(W, i)).replaceAll(";", "\n");
        // System.out.println(learningRates);
        W.addi(Nd4j.randn(rows, cols));
    }
}
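The commented-out lines in both tests hint at how the updater is normally exercised: a gradient array is handed to the updater, which rescales it using the state view configured earlier. A minimal sketch follows; it assumes an applyUpdater(INDArray gradient, int iteration) overload that modifies the gradient in place, whereas older nd4j releases expose getGradient(...) instead, as the two commented calls suggest.

    // Sketch: applying the updater to a fresh gradient each iteration.
    // Assumes applyUpdater(INDArray, int) rescales the gradient in place in this nd4j version.
    for (int iteration = 0; iteration < 5; iteration++) {
        INDArray gradient = Nd4j.randn(rows, cols);   // stand-in for a real gradient
        grad.applyUpdater(gradient, iteration);       // turn the raw gradient into an update
        W.subi(gradient);                             // apply the update to the parameters
    }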