Use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j (by deeplearning4j).
Source: class BackTrackLineSearchTest, method testBackTrackLineLBFGS.
/**
 * Fits a ReLU network on one normalized Iris batch using L-BFGS and asserts that the score
 * reported after fitting is strictly lower than the score measured before training.
 */
@Test
public void testBackTrackLineLBFGS() {
    OptimizationAlgorithm algo = OptimizationAlgorithm.LBFGS;

    // Normalize the batch in place so the untrained and trained scores are computed on the same inputs.
    DataSet irisBatch = irisIter.next();
    irisBatch.normalizeZeroMeanZeroUnitVariance();

    MultiLayerNetwork net =
            new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, 5, algo));
    net.init();

    IterationListener scoreListener = new ScoreIterationListener(1);
    net.setListeners(Collections.singletonList(scoreListener));

    double scoreBeforeFit = net.score(irisBatch);
    net.fit(irisBatch.getFeatureMatrix(), irisBatch.getLabels());
    double scoreAfterFit = net.score();

    assertTrue(scoreAfterFit < scoreBeforeFit);
}
Use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j (by deeplearning4j).
Source: class BackTrackLineSearchTest, method testBackTrackLineHessian.
/**
 * Expects fitting an Iris network configured with {@link OptimizationAlgorithm#HESSIAN_FREE}
 * to throw.
 *
 * NOTE(review): {@code expected = Exception.class} is satisfied by any exception from this body,
 * including ones raised during setup (e.g. {@code irisIter.next()}). Consider narrowing it to the
 * specific exception type HESSIAN_FREE is meant to raise (TODO confirm which).
 */
@Test(expected = Exception.class)
public void testBackTrackLineHessian() {
    OptimizationAlgorithm algo = OptimizationAlgorithm.HESSIAN_FREE;
    DataSet irisBatch = irisIter.next();

    MultiLayerNetwork net =
            new MultiLayerNetwork(getIrisMultiLayerConfig(Activation.RELU, 100, algo));
    net.init();

    IterationListener scoreListener = new ScoreIterationListener(1);
    net.setListeners(Collections.singletonList(scoreListener));

    // The expected exception should surface from this call.
    net.fit(irisBatch.getFeatureMatrix(), irisBatch.getLabels());
}
Use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j (by deeplearning4j).
Source: class TestOptimizers, method testOptimizersBasicMLPBackprop.
/**
 * Smoke test: for each listed optimization algorithm, builds a basic MLP on Iris and fits it
 * once. The test only checks that no exception escapes; it makes no assertion about the result.
 */
@Test
public void testOptimizersBasicMLPBackprop() {
    DataSetIterator irisData = new IrisDataSetIterator(5, 50);
    OptimizationAlgorithm[] algorithms = new OptimizationAlgorithm[] {
            OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT,
            OptimizationAlgorithm.LINE_GRADIENT_DESCENT,
            OptimizationAlgorithm.CONJUGATE_GRADIENT,
            OptimizationAlgorithm.LBFGS
    };

    for (int i = 0; i < algorithms.length; i++) {
        MultiLayerNetwork net = new MultiLayerNetwork(getMLPConfigIris(algorithms[i], 1));
        net.init();
        // Every algorithm shares one iterator, so rewind it before each fit.
        irisData.reset();
        net.fit(irisData);
    }
}
Use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j (by deeplearning4j).
Source: class TestComputationGraphNetwork, method testOptimizationAlgorithmsSearchBasic.
/**
 * Smoke test: for each listed optimization algorithm, builds a minimal ComputationGraph
 * (dense 4 -> 5, output 5 -> 3) and fits it on the iterator's single Iris example.
 * Only checks that no exception escapes.
 */
@Test
public void testOptimizationAlgorithmsSearchBasic() {
    DataSetIterator iter = new IrisDataSetIterator(1, 1);
    OptimizationAlgorithm[] oas = new OptimizationAlgorithm[] {
            OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT,
            OptimizationAlgorithm.LINE_GRADIENT_DESCENT,
            OptimizationAlgorithm.CONJUGATE_GRADIENT,
            OptimizationAlgorithm.LBFGS
    };
    for (OptimizationAlgorithm oa : oas) {
        System.out.println(oa);
        ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .optimizationAlgo(oa)
                .iterations(1)
                .graphBuilder()
                .addInputs("input")
                .addLayer("first", new DenseLayer.Builder().nIn(4).nOut(5).build(), "input")
                .addLayer("output", new OutputLayer.Builder().nIn(5).nOut(3).build(), "first")
                .setOutputs("output")
                .pretrain(false)
                .backprop(true)
                .build();
        ComputationGraph net = new ComputationGraph(conf);
        net.init();
        // The iterator was created with numExamples = 1. Rewind it before every next() so each
        // algorithm fits on the same example and next() is never called on an exhausted
        // iterator. This matches the reset-per-algorithm pattern used by the other optimizer tests.
        iter.reset();
        net.fit(iter.next());
    }
}
Use of org.deeplearning4j.nn.api.OptimizationAlgorithm in project deeplearning4j (by deeplearning4j).
Source: class BackTrackLineSearchTest, method testBackTrackLineCG.
/**
 * Fits a ReLU network on one normalized Iris batch using conjugate gradient and asserts that
 * the score reported after fitting is strictly lower than the initial score.
 */
@Test
public void testBackTrackLineCG() {
    OptimizationAlgorithm conjugateGradient = OptimizationAlgorithm.CONJUGATE_GRADIENT;

    DataSet batch = irisIter.next();
    // Normalized in place; both scores below are computed on this normalized batch.
    batch.normalizeZeroMeanZeroUnitVariance();

    MultiLayerNetwork model = new MultiLayerNetwork(
            getIrisMultiLayerConfig(Activation.RELU, 5, conjugateGradient));
    model.init();

    IterationListener listener = new ScoreIterationListener(1);
    model.setListeners(Collections.singletonList(listener));

    double initialScore = model.score(batch);
    model.fit(batch.getFeatureMatrix(), batch.getLabels());
    double trainedScore = model.score();

    assertTrue(trainedScore < initialScore);
}
Aggregations