Search in sources:

Example 31 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

From class ScoreFlatMapFunctionCGMultiDataSetAdapter, method call.

/**
 * Scores the MultiDataSets supplied by {@code dataSetIterator} with a ComputationGraph rebuilt from
 * the configuration JSON held in {@code json} and the parameter vector held in {@code params}.
 *
 * <p>Input is regrouped into minibatches of {@code minibatchSize} examples via
 * {@link IteratorMultiDataSetIterator}. For each minibatch one tuple
 * {@code (numExamples, score * numExamples)} is emitted. NOTE(review): this presumably relies on
 * {@code score(ds, false)} being a per-example average, so the second element is a sum the caller
 * can re-average; confirm.
 *
 * @param dataSetIterator the examples to score
 * @return one tuple per minibatch, or a single {@code (0, 0.0)} tuple when there is no input
 * @throws IllegalStateException if the length of the broadcast parameter vector differs from the
 *     network's parameter count
 */
@Override
public Iterable<Tuple2<Integer, Double>> call(Iterator<MultiDataSet> dataSetIterator) throws Exception {
    if (!dataSetIterator.hasNext()) {
        // Emit a neutral (count 0, sum 0.0) tuple rather than nothing for empty input.
        return Collections.singletonList(new Tuple2<>(0, 0.0));
    }
    // Does batching where appropriate
    MultiDataSetIterator iter = new IteratorMultiDataSetIterator(dataSetIterator, minibatchSize);
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(json));
    network.init();
    // .value() is shared by all executors on a single machine -> OK, as params are not changed in the score function
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false)) {
        // Report both counts so a config/parameter mismatch can be diagnosed from the message alone.
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters"
                + " (network: " + network.numParams(false) + ", broadcast: " + val.length() + ")");
    }
    network.setParams(val);
    List<Tuple2<Integer, Double>> out = new ArrayList<>();
    while (iter.hasNext()) {
        MultiDataSet ds = iter.next();
        double score = network.score(ds, false);
        // Example count is taken from the leading dimension of the first input array.
        int numExamples = ds.getFeatures(0).size(0);
        out.add(new Tuple2<>(numExamples, score * numExamples));
    }
    // With a GridExecutioner, flushQueueBlocking presumably waits for any queued ops before returning.
    if (Nd4j.getExecutioner() instanceof GridExecutioner) {
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    }
    return out;
}
Also used: IteratorMultiDataSetIterator(org.deeplearning4j.datasets.iterator.IteratorMultiDataSetIterator) ArrayList(java.util.ArrayList) MultiDataSetIterator(org.nd4j.linalg.dataset.api.iterator.MultiDataSetIterator) GridExecutioner(org.nd4j.linalg.api.ops.executioner.GridExecutioner) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.api.MultiDataSet) Tuple2(scala.Tuple2) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph)

Example 32 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

From class TransferLearningHelperTest, method testFitUnFrozen.

/**
 * Trains a clone with the first three centre layers frozen on raw data, and trains the original
 * through a TransferLearningHelper (frozen up to "denseCentre2") on featurized data. It then checks
 * that both routes end with the same parameters in every layer that is compared.
 */
@Test
public void testFitUnFrozen() {
    // Shared defaults for every layer: identity activations, plain SGD, fixed seed.
    NeuralNetConfiguration.Builder baseConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.9)
            .seed(124)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);
    ComputationGraphConfiguration graphConf = baseConf.graphBuilder()
            .addInputs("inCentre", "inRight")
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(10).nOut(9).build(), "inCentre")
            .addLayer("denseCentre1", new DenseLayer.Builder().nIn(9).nOut(8).build(), "denseCentre0")
            .addLayer("denseCentre2", new DenseLayer.Builder().nIn(8).nOut(7).build(), "denseCentre1")
            .addLayer("denseCentre3", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("outCentre", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(7).nOut(4).build(), "denseCentre3")
            .addVertex("subsetLeft", new SubsetVertex(0, 3), "denseCentre1")
            .addLayer("denseLeft0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "subsetLeft")
            .addLayer("outLeft", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(6).build(), "denseLeft0")
            .addLayer("denseRight", new DenseLayer.Builder().nIn(7).nOut(7).build(), "denseCentre2")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseRight", "denseRight0")
            .addLayer("denseRight1", new DenseLayer.Builder().nIn(10).nOut(5).build(), "mergeRight")
            .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(5).nOut(5).build(), "denseRight1")
            .setOutputs("outLeft", "outCentre", "outRight")
            .build();
    ComputationGraph tunedModel = new ComputationGraph(graphConf);
    tunedModel.init();

    // Random inputs/labels: 10 rows each; the draw order is kept as-is so the RNG stream is unchanged.
    INDArray rightFeatures = Nd4j.rand(10, 2);
    INDArray centreFeatures = Nd4j.rand(10, 10);
    INDArray leftLabels = Nd4j.rand(10, 6);
    INDArray rightLabels = Nd4j.rand(10, 5);
    INDArray centreLabels = Nd4j.rand(10, 4);
    MultiDataSet rawData = new MultiDataSet(
            new INDArray[] {centreFeatures, rightFeatures},
            new INDArray[] {leftLabels, centreLabels, rightLabels});

    // Reference: a clone of the untrained graph with the first three centre layers frozen in place.
    ComputationGraph referenceModel = tunedModel.clone();
    for (String frozenName : new String[] {"denseCentre0", "denseCentre1", "denseCentre2"}) {
        referenceModel.getVertex(frozenName).setLayerAsFrozen();
    }

    TransferLearningHelper helper = new TransferLearningHelper(tunedModel, "denseCentre2");
    MultiDataSet featurized = helper.featurize(rawData);
    assertEquals(referenceModel.getLayer("denseRight0").params(), tunedModel.getLayer("denseRight0").params());

    referenceModel.fit(rawData);
    helper.fitFeaturized(featurized);

    for (String layerName : new String[] {"denseCentre0", "denseCentre1", "denseCentre2", "denseCentre3", "outCentre"}) {
        assertEquals(referenceModel.getLayer(layerName).params(), tunedModel.getLayer(layerName).params());
    }
    assertEquals(referenceModel.getLayer("denseRight").conf().toJson(), tunedModel.getLayer("denseRight").conf().toJson());
    assertEquals(referenceModel.getLayer("denseRight").params(), tunedModel.getLayer("denseRight").params());
    assertEquals(referenceModel.getLayer("denseRight0").conf().toJson(), tunedModel.getLayer("denseRight0").conf().toJson());
    // NOTE(review): a post-fit params comparison for "denseRight0" was disabled here; confirm why before re-enabling it.
    for (String layerName : new String[] {"denseRight1", "outRight", "denseLeft0", "outLeft"}) {
        assertEquals(referenceModel.getLayer(layerName).params(), tunedModel.getLayer(layerName).params());
    }

    log.info(referenceModel.summary());
    log.info(helper.unfrozenGraph().summary());
}
Also used : OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) SubsetVertex(org.deeplearning4j.nn.conf.graph.SubsetVertex) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)

Example 33 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

From class TransferLearningCompGraphTest, method testNoutChanges.

/**
 * Verifies {@code nOutReplace} on a hidden layer ("layer0") and on the output layer ("layer3").
 * Requested weight-init settings must be applied, and the edited graph must have the same parameter
 * shapes and training behaviour as a graph built directly with the target architecture.
 */
@Test
public void testNoutChanges() {
    DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 2));
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder().learningRate(0.1)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.SGD)
            .activation(Activation.IDENTITY);
    FineTuneConfiguration fineTuneConfiguration = new FineTuneConfiguration.Builder().learningRate(0.1)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT).updater(Updater.SGD)
            .activation(Activation.IDENTITY).build();
    // Source graph 4 -> 5 -> 2 -> 3 -> softmax(3). layer1.nIn must equal layer0.nOut (5) so the source
    // graph is self-consistent. With that in place, the architecture checks below also prove that
    // nOutReplace on layer0 rewrites layer1's nIn to 3.
    ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
            .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "layer0In")
            .addLayer("layer1", new DenseLayer.Builder().nIn(5).nOut(2).build(), "layer0")
            .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1")
            .addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2")
            .setOutputs("layer3").build());
    modelToFineTune.init();
    // layer3: nOut 3 -> 2 with XAVIER. layer0: nOut 5 -> 3 initialised from NormalDistribution(1, 1e-1),
    // with XAVIER for the layer it feeds (checked on layer1 below).
    ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune)
            .fineTuneConfiguration(fineTuneConfiguration)
            .nOutReplace("layer3", 2, WeightInit.XAVIER)
            .nOutReplace("layer0", 3, new NormalDistribution(1, 1e-1), WeightInit.XAVIER)
            .build();

    // JUnit's assertEquals takes (expected, actual); keep that order so failure messages are accurate.
    assertEquals(WeightInit.DISTRIBUTION, modelNow.getLayer("layer0").conf().getLayer().getWeightInit());
    assertEquals(new NormalDistribution(1, 1e-1), modelNow.getLayer("layer0").conf().getLayer().getDist());
    assertEquals(WeightInit.XAVIER, modelNow.getLayer("layer1").conf().getLayer().getWeightInit());
    assertEquals(null, modelNow.getLayer("layer1").conf().getLayer().getDist());
    assertEquals(WeightInit.XAVIER, modelNow.getLayer("layer3").conf().getLayer().getWeightInit());

    // Target architecture: 4 -> 3 -> 2 -> 3 -> softmax(2)
    ComputationGraph modelExpectedArch = new ComputationGraph(overallConf.graphBuilder().addInputs("layer0In")
            .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In")
            .addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0")
            .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1")
            .addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(3).nOut(2).build(), "layer2")
            .setOutputs("layer3").build());
    modelExpectedArch.init();

    // modelNow should have the same architecture as modelExpectedArch
    assertArrayEquals(modelExpectedArch.params().shape(), modelNow.params().shape());
    for (String layerName : new String[] {"layer0", "layer1", "layer2", "layer3"}) {
        assertArrayEquals(modelExpectedArch.getLayer(layerName).params().shape(), modelNow.getLayer(layerName).params().shape());
    }

    // Starting from identical parameters, fitting on the same data should give the same results.
    modelNow.setParams(modelExpectedArch.params());
    modelExpectedArch.fit(randomData);
    modelNow.fit(randomData);
    assertEquals(modelExpectedArch.score(), modelNow.score(), 1e-8);
    assertEquals(modelExpectedArch.params(), modelNow.params());
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) NormalDistribution(org.deeplearning4j.nn.conf.distribution.NormalDistribution) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)

Example 34 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

From class TransferLearningCompGraphTest, method testRemoveAndAdd.

/**
 * Widens two layers with {@code nOutReplace}, then removes and re-adds the output layer. Checks that
 * the result matches, in parameter shapes and in training behaviour, a graph built directly with the
 * target architecture.
 */
@Test
public void testRemoveAndAdd() {
    // 10 random examples with 4 feature columns and 3 label columns.
    DataSet data = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
    NeuralNetConfiguration.Builder baseConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.1)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD)
            .activation(Activation.IDENTITY);
    FineTuneConfiguration fineTune = new FineTuneConfiguration.Builder()
            .learningRate(0.1)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD)
            .activation(Activation.IDENTITY)
            .build();

    // Source graph: 4 -> 5 -> 2 -> 3 -> softmax(3)
    ComputationGraphConfiguration sourceConf = baseConf.graphBuilder()
            .addInputs("layer0In")
            .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(5).build(), "layer0In")
            .addLayer("layer1", new DenseLayer.Builder().nIn(5).nOut(2).build(), "layer0")
            .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1")
            .addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2")
            .setOutputs("layer3")
            .build();
    ComputationGraph sourceModel = new ComputationGraph(sourceConf);
    sourceModel.init();

    // Widen layer0 to 7 and layer2 to 5, then replace the output layer with one sized for layer2's new width.
    ComputationGraph edited = new TransferLearning.GraphBuilder(sourceModel)
            .fineTuneConfiguration(fineTune)
            .nOutReplace("layer0", 7, WeightInit.XAVIER, WeightInit.XAVIER)
            .nOutReplace("layer2", 5, WeightInit.XAVIER)
            .removeVertexKeepConnections("layer3")
            .addLayer("layer3", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(5).nOut(3)
                    .activation(Activation.SOFTMAX).build(), "layer2")
            .build();

    // Target architecture: 4 -> 7 -> 2 -> 5 -> softmax(3)
    ComputationGraphConfiguration expectedConf = baseConf.graphBuilder()
            .addInputs("layer0In")
            .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(7).build(), "layer0In")
            .addLayer("layer1", new DenseLayer.Builder().nIn(7).nOut(2).build(), "layer0")
            .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(5).build(), "layer1")
            .addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(5).nOut(3).build(), "layer2")
            .setOutputs("layer3")
            .build();
    ComputationGraph expected = new ComputationGraph(expectedConf);
    expected.init();

    // Whole-network and per-layer parameter shapes must agree.
    assertArrayEquals(expected.params().shape(), edited.params().shape());
    for (String layer : new String[] {"layer0", "layer1", "layer2", "layer3"}) {
        assertArrayEquals(expected.getLayer(layer).params().shape(), edited.getLayer(layer).params().shape());
    }

    // From identical starting parameters, fitting the same data must produce identical score and parameters.
    edited.setParams(expected.params());
    expected.fit(data);
    edited.fit(data);
    assertEquals(expected.score(), edited.score(), 1e-8);
    assertEquals(expected.params(), edited.params());
}
Also used : DataSet(org.nd4j.linalg.dataset.DataSet) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)

Example 35 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

From class TransferLearningComplex, method testSimplerMergeBackProp.

/**
 * Compares three models over five training steps:
 * <ul>
 *   <li>the full merge graph with "denseCentre0" frozen;</li>
 *   <li>a feature-extractor graph built from it with TransferLearning;</li>
 *   <li>a reduced graph that receives denseCentre0's precomputed activations as an input.</li>
 * </ul>
 * All three must agree on merge activations and on the trainable layers' parameters.
 */
@Test
public void testSimplerMergeBackProp() {
    NeuralNetConfiguration.Builder baseConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.9)
            .activation(Activation.IDENTITY)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD);
    /*
                inCentre                inRight
                   |                        |
             denseCentre0               denseRight0
                   |                        |
                   |------ mergeRight ------|
                                |
                              outRight

        */
    ComputationGraphConfiguration fullConf = baseConf.graphBuilder()
            .addInputs("inCentre", "inRight")
            .addLayer("denseCentre0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inCentre")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
            .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(), "mergeRight")
            .setOutputs("outRight")
            .build();
    ComputationGraph fullModel = new ComputationGraph(fullConf);
    fullModel.init();
    MultiDataSet rawData = new MultiDataSet(
            new INDArray[] {Nd4j.rand(2, 2), Nd4j.rand(2, 2)},
            new INDArray[] {Nd4j.rand(2, 2)});

    // denseCentre0's inference-mode activations replace inCentre as the first input of the reduced graph.
    INDArray centreActivations = fullModel.feedForward(rawData.getFeatures(), false).get("denseCentre0");
    MultiDataSet featurizedData = new MultiDataSet(
            new INDArray[] {centreActivations, rawData.getFeatures(1)},
            rawData.getLabels());

    // Reduced graph: denseCentre0 becomes an input instead of a layer.
    ComputationGraphConfiguration headConf = baseConf.graphBuilder()
            .addInputs("denseCentre0", "inRight")
            .addLayer("denseRight0", new DenseLayer.Builder().nIn(2).nOut(2).build(), "inRight")
            .addVertex("mergeRight", new MergeVertex(), "denseCentre0", "denseRight0")
            .addLayer("outRight", new OutputLayer.Builder(LossFunctions.LossFunction.MSE).nIn(4).nOut(2).build(), "mergeRight")
            .setOutputs("outRight")
            .build();
    ComputationGraph headModel = new ComputationGraph(headConf);
    headModel.init();
    // Start the shared trainable layers from the full model's parameters.
    headModel.getLayer("denseRight0").setParams(fullModel.getLayer("denseRight0").params());
    headModel.getLayer("outRight").setParams(fullModel.getLayer("outRight").params());

    fullModel.getVertex("denseCentre0").setLayerAsFrozen();
    ComputationGraph extractedModel = new TransferLearning.GraphBuilder(fullModel).setFeatureExtractor("denseCentre0").build();

    // Before any training, activations out of the merge vertex must already agree.
    assertEquals(fullModel.feedForward(rawData.getFeatures(), false).get("mergeRight"),
            headModel.feedForward(featurizedData.getFeatures(), false).get("mergeRight"));
    assertEquals(extractedModel.feedForward(rawData.getFeatures(), false).get("mergeRight"),
            headModel.feedForward(featurizedData.getFeatures(), false).get("mergeRight"));

    for (int step = 0; step < 5; step++) {
        headModel.fit(featurizedData);
        fullModel.fit(rawData);
        extractedModel.fit(rawData);
        // The frozen vertex's activations must still equal the precomputed input of the reduced graph.
        assertEquals(featurizedData.getFeatures(0), extractedModel.feedForward(rawData.getFeatures(), false).get("denseCentre0"));
        assertEquals(featurizedData.getFeatures(0), fullModel.feedForward(rawData.getFeatures(), false).get("denseCentre0"));
        // The trainable layers must have identical parameters in all three models.
        assertEquals(headModel.getLayer("denseRight0").params(), extractedModel.getLayer("denseRight0").params());
        assertEquals(headModel.getLayer("denseRight0").params(), fullModel.getLayer("denseRight0").params());
        assertEquals(headModel.getLayer("outRight").params(), extractedModel.getLayer("outRight").params());
        assertEquals(headModel.getLayer("outRight").params(), fullModel.getLayer("outRight").params());
    }
}
Also used : DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) MultiDataSet(org.nd4j.linalg.dataset.MultiDataSet) INDArray(org.nd4j.linalg.api.ndarray.INDArray) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) NeuralNetConfiguration(org.deeplearning4j.nn.conf.NeuralNetConfiguration) MergeVertex(org.deeplearning4j.nn.conf.graph.MergeVertex) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) Test(org.junit.Test)

Aggregations

ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 109; Test (org.junit.Test): 73; ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 63; INDArray (org.nd4j.linalg.api.ndarray.INDArray): 62; NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 36; DataSet (org.nd4j.linalg.dataset.DataSet): 25; NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution): 22; OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 21; DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 19; MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 19; ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 17; DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 17; IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 14; Layer (org.deeplearning4j.nn.api.Layer): 14; Random (java.util.Random): 11; InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 10; MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 10; TrainingMaster (org.deeplearning4j.spark.api.TrainingMaster): 10; MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 9; GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 9