
Example 1 with DefaultStepFunction

Use of org.deeplearning4j.optimize.stepfunctions.DefaultStepFunction in project deeplearning4j by deeplearning4j.

From the class BackTrackLineSearchTest, method testMultMaxLineSearch.

@Test
public void testMultMaxLineSearch() throws Exception {
    double score1, score2;
    // irisData (a DataSet) and getIrisLogisticLayerConfig(...) are defined elsewhere in the test class
    irisData.normalizeZeroMeanZeroUnitVariance();
    OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, LossFunctions.LossFunction.MCXENT);
    int nParams = layer.numParams();
    // Give the layer a flattened view array to hold its backprop gradients
    layer.setBackpropGradientsViewArray(Nd4j.create(1, nParams));
    layer.setInput(irisData.getFeatureMatrix());
    layer.setLabels(irisData.getLabels());
    // Score and gradient at the starting parameters
    layer.computeGradientAndScore();
    score1 = layer.score();
    INDArray origGradient = layer.gradient().gradient().dup();
    // DefaultStepFunction moves the parameters along (not against) the search direction
    DefaultStepFunction sf = new DefaultStepFunction();
    BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, sf, layer.getOptimizer());
    // Arguments: parameters, gradient, search direction (here the gradient itself)
    double step = lineSearch.optimize(layer.params().dup(), layer.gradient().gradient().dup(), layer.gradient().gradient().dup());
    // Apply the chosen step to the live parameters and rescore
    INDArray currParams = layer.params();
    sf.step(currParams, origGradient, step);
    layer.setParams(currParams);
    layer.computeGradientAndScore();
    score2 = layer.score();
    // Stepping along the gradient should increase the score (maximization)
    assertTrue("score1 = " + score1 + ", score2 = " + score2, score1 < score2);
}
Also used: OutputLayer (org.deeplearning4j.nn.layers.OutputLayer), BackTrackLineSearch (org.deeplearning4j.optimize.solvers.BackTrackLineSearch), INDArray (org.nd4j.linalg.api.ndarray.INDArray), NegativeDefaultStepFunction (org.deeplearning4j.optimize.stepfunctions.NegativeDefaultStepFunction), DefaultStepFunction (org.deeplearning4j.optimize.stepfunctions.DefaultStepFunction), Test (org.junit.Test)
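The assertion score1 < score2 depends on the sign convention of the step function. As a rough, dependency-free illustration, assuming (from our reading of the DL4J source) that DefaultStepFunction computes x += step * direction and NegativeDefaultStepFunction computes x -= step * direction, the sketch below uses plain double[] arrays in place of INDArray. The class StepFunctionSketch, its methods, and the toy objective are hypothetical and not part of DL4J.

```java
import java.util.Arrays;

/**
 * Hypothetical, dependency-free sketch of the two step-function conventions.
 * Assumption: DefaultStepFunction ~ x += step * dir,
 *             NegativeDefaultStepFunction ~ x -= step * dir.
 * Plain double[] arrays stand in for INDArray.
 */
public class StepFunctionSketch {

    /** x <- x + step * dir, applied in place (an axpy-style update). */
    static void defaultStep(double[] x, double[] dir, double step) {
        for (int i = 0; i < x.length; i++) {
            x[i] += step * dir[i];
        }
    }

    /** x <- x - step * dir, applied in place. */
    static void negativeStep(double[] x, double[] dir, double step) {
        for (int i = 0; i < x.length; i++) {
            x[i] -= step * dir[i];
        }
    }

    /** Concave toy objective f(x) = -sum((x_i - 1)^2), maximized at x = (1, ..., 1). */
    static double score(double[] x) {
        double s = 0.0;
        for (double v : x) {
            s -= (v - 1.0) * (v - 1.0);
        }
        return s;
    }

    /** Gradient of the toy objective: df/dx_i = -2 (x_i - 1). */
    static double[] gradient(double[] x) {
        double[] g = new double[x.length];
        for (int i = 0; i < x.length; i++) {
            g[i] = -2.0 * (x[i] - 1.0);
        }
        return g;
    }

    public static void main(String[] args) {
        double[] x = {0.0, 0.0};
        double score1 = score(x);
        double[] grad = gradient(x);

        // Step along the gradient (default convention)
        double[] up = x.clone();
        defaultStep(up, grad, 0.25);
        // Step against the gradient (negative convention)
        double[] down = x.clone();
        negativeStep(down, grad, 0.25);

        System.out.println("start:    " + Arrays.toString(x) + " score " + score1);
        System.out.println("default:  " + Arrays.toString(up) + " score " + score(up));
        System.out.println("negative: " + Arrays.toString(down) + " score " + score(down));
    }
}
```

Starting from (0, 0) with step 0.25, the default convention moves to (0.5, 0.5) and raises the toy score from -2.0 to -0.5, while the negative convention moves to (-0.5, -0.5) and lowers it to -4.5. This mirrors why the test, which uses DefaultStepFunction with the gradient as the search direction, expects the score to go up.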

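In the test, BackTrackLineSearch.optimize picks the step size that is then passed to sf.step. For intuition only, here is a generic Armijo backtracking line search for maximization: start from an initial step and halve it until the new score exceeds the old one by at least c * step * (gradient . direction). This is a textbook technique, not DL4J's implementation; the class BacktrackSketch, the halving factor 0.5, and the constant c = 1e-4 are illustrative assumptions.

```java
import java.util.function.Function;

/**
 * Hypothetical sketch of an Armijo backtracking line search for MAXIMIZING a score.
 * Not DL4J's BackTrackLineSearch; constants and names are illustrative.
 */
public class BacktrackSketch {

    static double dot(double[] a, double[] b) {
        double s = 0.0;
        for (int i = 0; i < a.length; i++) {
            s += a[i] * b[i];
        }
        return s;
    }

    /** Returns a new array x + step * dir (x itself is left unchanged). */
    static double[] axpy(double[] x, double[] dir, double step) {
        double[] out = x.clone();
        for (int i = 0; i < out.length; i++) {
            out[i] += step * dir[i];
        }
        return out;
    }

    /**
     * Halve the step, starting at initialStep, until the sufficient-increase condition
     * f(x + t d) >= f(x) + c * t * (g . d) holds. Returns 0.0 if d is not an ascent
     * direction or no acceptable step is found within maxIter trials.
     */
    static double search(Function<double[], Double> f, double[] x, double[] grad,
                         double[] dir, double initialStep, double c, int maxIter) {
        double f0 = f.apply(x);
        double slope = dot(grad, dir);   // directional derivative along dir
        if (slope <= 0.0) {
            return 0.0;                  // dir does not increase the score locally
        }
        double t = initialStep;
        for (int k = 0; k < maxIter; k++) {
            if (f.apply(axpy(x, dir, t)) >= f0 + c * t * slope) {
                return t;
            }
            t *= 0.5;
        }
        return 0.0;
    }

    public static void main(String[] args) {
        // Concave toy objective f(x) = -sum((x_i - 1)^2), maximized at (1, 1)
        Function<double[], Double> f = v -> {
            double s = 0.0;
            for (double e : v) {
                s -= (e - 1.0) * (e - 1.0);
            }
            return s;
        };
        double[] x = {0.0, 0.0};
        double[] grad = {2.0, 2.0};      // gradient of f at x; also used as the search direction
        double step = search(f, x, grad, grad, 4.0, 1e-4, 20);
        double[] next = axpy(x, grad, step);
        System.out.println("step = " + step + ", score " + f.apply(x) + " -> " + f.apply(next));
    }
}
```

For this toy objective the trial steps 4, 2 and 1 overshoot the maximum at (1, 1) and fail the condition, so the search settles on 0.5 and the score rises from -2.0 to 0.0, the same "score increases after the line-search step" property that testMultMaxLineSearch asserts.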