Use of org.deeplearning4j.optimize.stepfunctions.DefaultStepFunction in project deeplearning4j by deeplearning4j.
From the class BackTrackLineSearchTest, method testMultMaxLineSearch.
/**
 * Runs a back-tracking line search on an Iris softmax output layer using
 * {@link DefaultStepFunction}, applies one step of the returned size to the
 * layer's parameters, and asserts that the layer's score after the step is
 * strictly greater than the score before it.
 *
 * @throws Exception if any of the layer or line-search calls fail
 */
@Test
public void testMultMaxLineSearch() throws Exception {
    irisData.normalizeZeroMeanZeroUnitVariance();

    OutputLayer layer = getIrisLogisticLayerConfig(Activation.SOFTMAX, 100, LossFunctions.LossFunction.MCXENT);
    int paramCount = layer.numParams();
    layer.setBackpropGradientsViewArray(Nd4j.create(1, paramCount));
    layer.setInput(irisData.getFeatureMatrix());
    layer.setLabels(irisData.getLabels());

    // Baseline: score and a private copy of the gradient before any step is taken.
    layer.computeGradientAndScore();
    double scoreBefore = layer.score();
    INDArray baselineGradient = layer.gradient().gradient().dup();

    DefaultStepFunction stepFunction = new DefaultStepFunction();
    BackTrackLineSearch lineSearch = new BackTrackLineSearch(layer, stepFunction, layer.getOptimizer());

    // Each argument is an independent copy, so the search never receives
    // the layer's own arrays or the baseline gradient saved above.
    INDArray paramsCopy = layer.params().dup();
    INDArray firstGradientCopy = layer.gradient().gradient().dup();
    INDArray secondGradientCopy = layer.gradient().gradient().dup();
    double stepSize = lineSearch.optimize(paramsCopy, firstGradientCopy, secondGradientCopy);

    // Apply the step to the layer's current parameters using the baseline gradient.
    INDArray steppedParams = layer.params();
    stepFunction.step(steppedParams, baselineGradient, stepSize);
    layer.setParams(steppedParams);

    layer.computeGradientAndScore();
    double scoreAfter = layer.score();

    String failureMessage = "score1 = " + scoreBefore + ", score2 = " + scoreAfter;
    assertTrue(failureMessage, scoreBefore < scoreAfter);
}
Aggregations