Example 1 with FineTuneConfiguration

Use of org.deeplearning4j.nn.transferlearning.FineTuneConfiguration in the deeplearning4j project.

From the class FrozenLayerTest, method testFrozen (a condensed usage sketch follows the listing):

/*
    A model with a few frozen layers should produce the same results as a model built from only the
    non-frozen layers, fed with the output of the forward pass through the frozen layers.
 */
@Test
public void testFrozen() {
    DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder()
                    .learningRate(0.1)
                    .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
                    .updater(Updater.SGD)
                    .activation(Activation.IDENTITY);
    FineTuneConfiguration finetune = new FineTuneConfiguration.Builder().learningRate(0.1).build();
    MultiLayerNetwork modelToFineTune = new MultiLayerNetwork(overallConf.clone().list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
                    .layer(1, new DenseLayer.Builder().nIn(3).nOut(2).build())
                    .layer(2, new DenseLayer.Builder().nIn(2).nOut(3).build())
                    .layer(3, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(3).nOut(3).build())
                    .build());
    modelToFineTune.init();
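    // ff.get(0) is the input itself, so ff.get(2) is the activation of layer 1: the output of the
    // layers that will be frozen below (width 2, matching notFrozen's nIn).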
    List<INDArray> ff = modelToFineTune.feedForwardToLayer(2, randomData.getFeatures(), false);
    INDArray asFrozenFeatures = ff.get(2);
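    // setFeatureExtractor(1) freezes layer 1 and every layer below it; only layers 2 and 3 stay trainable.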
    MultiLayerNetwork modelNow = new TransferLearning.Builder(modelToFineTune).fineTuneConfiguration(finetune).setFeatureExtractor(1).build();
    INDArray paramsLastTwoLayers = Nd4j.hstack(modelToFineTune.getLayer(2).params(), modelToFineTune.getLayer(3).params());
    MultiLayerNetwork notFrozen = new MultiLayerNetwork(overallConf.clone().list()
                    .layer(0, new DenseLayer.Builder().nIn(2).nOut(3).build())
                    .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                    .activation(Activation.SOFTMAX).nIn(3).nOut(3).build())
                    .build(), paramsLastTwoLayers);
    //        assertEquals(modelNow.getLayer(2).conf(), notFrozen.getLayer(0).conf());  //Equal, other than names
    //        assertEquals(modelNow.getLayer(3).conf(), notFrozen.getLayer(1).conf());  //Equal, other than names
    //Check: forward pass
    INDArray outNow = modelNow.output(randomData.getFeatures());
    INDArray outNotFrozen = notFrozen.output(asFrozenFeatures);
    assertEquals(outNow, outNotFrozen);
    for (int i = 0; i < 5; i++) {
        notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
        modelNow.fit(randomData);
    }
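    // The frozen layers of modelNow must still hold modelToFineTune's original parameters, and its
    // trainable layers must match the separately trained notFrozen network.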
    INDArray expected = Nd4j.hstack(modelToFineTune.getLayer(0).params(), modelToFineTune.getLayer(1).params(), notFrozen.params());
    INDArray act = modelNow.params();
    assertEquals(expected, act);
}
Also used :
DataSet (org.nd4j.linalg.dataset.DataSet)
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)
INDArray (org.nd4j.linalg.api.ndarray.INDArray)
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
FineTuneConfiguration (org.deeplearning4j.nn.transferlearning.FineTuneConfiguration)
TransferLearning (org.deeplearning4j.nn.transferlearning.TransferLearning)
Test (org.junit.Test)
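The test above exercises FineTuneConfiguration and TransferLearning inside a JUnit harness. The sketch below condenses the same calls into a standalone helper, as one might use them outside a test. It is a minimal sketch, not part of the original test: the method name, the learning-rate value, and the pretrained/trainingData arguments are illustrative assumptions, while the API calls themselves (FineTuneConfiguration.Builder, TransferLearning.Builder, fineTuneConfiguration, setFeatureExtractor, fit) all appear in the listing, and the required imports match the "Also used" classes.

public static MultiLayerNetwork fineTuneTopLayers(MultiLayerNetwork pretrained, DataSet trainingData) {
    // Hyperparameters applied to the layers that remain trainable; 0.01 is an illustrative value.
    FineTuneConfiguration fineTune = new FineTuneConfiguration.Builder().learningRate(0.01).build();
    // Freeze layer 1 and everything below it; layers above stay trainable.
    MultiLayerNetwork fineTuned = new TransferLearning.Builder(pretrained)
            .fineTuneConfiguration(fineTune)
            .setFeatureExtractor(1)
            .build();
    // Fitting updates only the non-frozen layers, which is what the assertion at the end of testFrozen verifies.
    fineTuned.fit(trainingData);
    return fineTuned;
}

setFeatureExtractor is the hinge here: everything at or below the given layer index is wrapped as a frozen layer, so only the layers above it receive gradient updates during fit.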

Aggregations

NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 1
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 1
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 1
FineTuneConfiguration (org.deeplearning4j.nn.transferlearning.FineTuneConfiguration): 1
TransferLearning (org.deeplearning4j.nn.transferlearning.TransferLearning): 1
Test (org.junit.Test): 1
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 1
DataSet (org.nd4j.linalg.dataset.DataSet): 1