use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j by deeplearning4j.
the class BackTrackLineSearchTest method testBackTrackLineGradientDescent.
///////////////////////////////////////////////////////////////////////////
@Test
public void testBackTrackLineGradientDescent() {
    OptimizationAlgorithm optimizer = OptimizationAlgorithm.LINE_GRADIENT_DESCENT;

    // A single Iris example keeps the optimization problem small
    DataSetIterator irisIter = new IrisDataSetIterator(1, 1);
    DataSet data = irisIter.next();

    MultiLayerNetwork network = new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.SIGMOID, 100, optimizer));
    network.init();
    IterationListener listener = new ScoreIterationListener(1);
    network.setListeners(Collections.singletonList(listener));

    double oldScore = network.score(data);
    network.fit(data.getFeatureMatrix(), data.getLabels());
    double score = network.score();

    // One fit with line-search gradient descent must reduce the score
    assertTrue(score < oldScore);
}
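The helper getIrisMultiLayerConfig is not included in the snippet above. A minimal sketch of what such a builder method could look like, assuming the pre-1.0 DL4J configuration API used throughout this page; the layer sizes follow the Iris dataset (4 features, 3 classes), but the seed, weight init, and loss function are illustrative assumptions, not taken from the test:

// Hypothetical reconstruction -- the actual test helper may differ.
private static MultiLayerConfiguration getIrisMultiLayerConfig(Activation activation, int iterations,
                OptimizationAlgorithm optimizer) {
    return new NeuralNetConfiguration.Builder()
                    .optimizationAlgo(optimizer)
                    .iterations(iterations)
                    .seed(12345) // assumption: any fixed seed
                    .weightInit(WeightInit.XAVIER) // assumption
                    .list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3)
                                    .activation(activation).build())
                    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT) // assumption
                                    .nIn(3).nOut(3).activation(Activation.SOFTMAX).build())
                    .pretrain(false).backprop(true)
                    .build();
}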
use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j by deeplearning4j.
the class TestOptimizers method testOptimizersMLP.
@Test
public void testOptimizersMLP() {
    // Check that the score actually decreases over time, for each optimizer
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    OptimizationAlgorithm[] toTest = { OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT,
                    OptimizationAlgorithm.LINE_GRADIENT_DESCENT,
                    OptimizationAlgorithm.CONJUGATE_GRADIENT,
                    OptimizationAlgorithm.LBFGS };
    DataSet ds = iter.next();
    ds.normalizeZeroMeanZeroUnitVariance();

    for (OptimizationAlgorithm oa : toTest) {
        int nIter = 10;
        MultiLayerNetwork network = new MultiLayerNetwork(getMLPConfigIris(oa, nIter));
        network.init();

        double score = network.score(ds);
        assertTrue(score != 0.0 && !Double.isNaN(score));

        if (PRINT_OPT_RESULTS)
            System.out.println("testOptimizersMLP() - " + oa);

        // Fit repeatedly; the score must never increase between calls
        int nCallsToOptimizer = 30;
        double[] scores = new double[nCallsToOptimizer + 1];
        scores[0] = score;
        for (int i = 0; i < nCallsToOptimizer; i++) {
            network.fit(ds);
            double scoreAfter = network.score(ds);
            scores[i + 1] = scoreAfter;
            assertTrue("Score is NaN after optimization", !Double.isNaN(scoreAfter));
            assertTrue("OA= " + oa + ", before= " + score + ", after= " + scoreAfter, scoreAfter <= score);
            score = scoreAfter;
        }

        if (PRINT_OPT_RESULTS)
            System.out.println(oa + " - " + Arrays.toString(scores));
    }
}
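All four algorithms tested here are selected through the same optimizationAlgo(...) builder call. As a minimal sketch of that switch (the maxNumLineSearchIterations call is an assumption from the same-era Builder API, not used in the test above; layer sizes are illustrative):

// Switching the optimizer is a single builder call; the full-batch
// algorithms (LINE_GRADIENT_DESCENT, CONJUGATE_GRADIENT, LBFGS) also
// run an internal line search at each step.
MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(OptimizationAlgorithm.CONJUGATE_GRADIENT)
                .iterations(10)
                .maxNumLineSearchIterations(5) // assumption: tunes the inner line search
                .list()
                .layer(0, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                .nIn(4).nOut(3).activation(Activation.SOFTMAX).build())
                .build();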
use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j by deeplearning4j.
the class TestDecayPolicies method testUpdatingInConf.
@Test
public void testUpdatingInConf() throws Exception {
    OptimizationAlgorithm[] optimizationAlgorithms = new OptimizationAlgorithm[] {
                    OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT,
                    OptimizationAlgorithm.LINE_GRADIENT_DESCENT,
                    OptimizationAlgorithm.CONJUGATE_GRADIENT,
                    OptimizationAlgorithm.LBFGS };
    for (OptimizationAlgorithm oa : optimizationAlgorithms) {
        // Momentum schedule: ramps up by 0.001 per iteration, capped at 0.9999
        Map<Integer, Double> momentumSchedule = new HashMap<>();
        double m = 0.001;
        for (int i = 0; i <= 100; i++) {
            momentumSchedule.put(i, Math.min(m, 0.9999));
            m += 0.001;
        }

        // Learning rate schedule: starts at 0.1 and decays by a factor of 0.96 per iteration
        Map<Integer, Double> learningRateSchedule = new HashMap<>();
        double lr = 0.1;
        for (int i = 0; i <= 100; i++) {
            learningRateSchedule.put(i, lr);
            lr *= 0.96;
        }

        MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                        .optimizationAlgo(oa)
                        .iterations(1)
                        .learningRateDecayPolicy(LearningRatePolicy.Schedule)
                        .learningRateSchedule(learningRateSchedule)
                        .updater(org.deeplearning4j.nn.conf.Updater.NESTEROVS)
                        .weightInit(WeightInit.XAVIER)
                        .momentum(0.9)
                        .momentumAfter(momentumSchedule)
                        .regularization(true).l2(0.0001)
                        .list()
                        .layer(0, new DenseLayer.Builder().nIn(784).nOut(10).build())
                        .layer(1, new OutputLayer.Builder().nIn(10).nOut(10).build())
                        .pretrain(false).backprop(true)
                        .build();
        MultiLayerNetwork net = new MultiLayerNetwork(conf);
        net.init();

        int lastLayerIndex = 1;
        DataSetIterator trainIter = new MnistDataSetIterator(64, true, 12345);
        int count = 0;
        while (trainIter.hasNext()) {
            net.fit(trainIter.next());
            // Without the schedules these would stay at the initial values (0.1 and 0.9);
            // with them, each iteration must pick up the scheduled values instead
            double lrLastLayer = net.getLayer(lastLayerIndex).conf().getLearningRateByParam("W");
            double mLastLayer = net.getLayer(lastLayerIndex).conf().getLayer().getMomentum();
            assertEquals(learningRateSchedule.get(count), lrLastLayer, 1e-6);
            assertEquals(momentumSchedule.get(count), mLastLayer, 1e-6);
            if (count++ >= 100)
                break;
        }
    }
}
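For reference, the two schedules built above reduce to lr(i) = 0.1 * 0.96^i and momentum(i) = min(0.001 * (i + 1), 0.9999). A standalone sketch printing the first few scheduled values:

// First few values of the schedules defined in the test above
for (int i = 0; i < 3; i++) {
    System.out.printf("iter %d: lr=%.5f, momentum=%.3f%n",
                    i, 0.1 * Math.pow(0.96, i), Math.min(0.001 * (i + 1), 0.9999));
}
// iter 0: lr=0.10000, momentum=0.001
// iter 1: lr=0.09600, momentum=0.002
// iter 2: lr=0.09216, momentum=0.003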