Search in sources:

Example 1 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

From the class CenterLossOutputLayerTest, method testMNISTConfig.

@Test
// Should be run manually - trains on MNIST for observation only, asserts nothing
@Ignore
public void testMNISTConfig() throws Exception {
    // Mini-batch size for the MNIST iterator
    int batchSize = 64;
    DataSetIterator mnistTrain = new MnistDataSetIterator(batchSize, true, 12345);
    ComputationGraph net = getCNNMnistConfig();
    net.init();
    net.setListeners(new ScoreIterationListener(1));
    for (int i = 0; i < 50; i++) {
        net.fit(mnistTrain.next());
        // Pause between fits so the score output can be watched as it runs
        Thread.sleep(1000);
    }
    // Keep the JVM alive for a while after training finishes
    Thread.sleep(100000);
}
Also used: MnistDataSetIterator (org.deeplearning4j.datasets.iterator.impl.MnistDataSetIterator), IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator), DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener), Ignore (org.junit.Ignore), Test (org.junit.Test)
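
The helper getCNNMnistConfig() is defined elsewhere in CenterLossOutputLayerTest and is not shown on this page. As a purely illustrative sketch of what a center-loss MNIST graph can look like (every layer size and hyperparameter below is an assumption, not the test's actual configuration):

// Hypothetical stand-in for getCNNMnistConfig(); illustrative values only
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
        .seed(12345)
        .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
        .graphBuilder()
        .addInputs("input")
        .setInputTypes(InputType.convolutionalFlat(28, 28, 1))
        .addLayer("cnn1", new ConvolutionLayer.Builder(5, 5).nOut(20)
                .activation(Activation.RELU).build(), "input")
        .addLayer("pool1", new SubsamplingLayer.Builder(SubsamplingLayer.PoolingType.MAX)
                .kernelSize(2, 2).stride(2, 2).build(), "cnn1")
        .addLayer("dense1", new DenseLayer.Builder().nOut(64)
                .activation(Activation.RELU).build(), "pool1")
        // CenterLossOutputLayer adds the center-loss term; alpha and lambda are illustrative
        .addLayer("output", new CenterLossOutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                .nOut(10).activation(Activation.SOFTMAX).alpha(0.1).lambda(2e-4).build(), "dense1")
        .setOutputs("output")
        .build();
ComputationGraph net = new ComputationGraph(conf);
net.init();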

Example 2 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

From the class FrozenLayerTest, method cloneCompGraphFrozen.

@Test
public void cloneCompGraphFrozen() {
    DataSet randomData = new DataSet(Nd4j.rand(10, 4), Nd4j.rand(10, 3));
    NeuralNetConfiguration.Builder overallConf = new NeuralNetConfiguration.Builder()
            .learningRate(0.1)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .updater(Updater.SGD)
            .activation(Activation.IDENTITY);
    ComputationGraph modelToFineTune = new ComputationGraph(overallConf.graphBuilder()
            .addInputs("layer0In")
            .addLayer("layer0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "layer0In")
            .addLayer("layer1", new DenseLayer.Builder().nIn(3).nOut(2).build(), "layer0")
            .addLayer("layer2", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer1")
            .addLayer("layer3", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer2")
            .setOutputs("layer3")
            .build());
    modelToFineTune.init();
    INDArray asFrozenFeatures = modelToFineTune.feedForward(randomData.getFeatures(), false).get("layer1");
    ComputationGraph modelNow = new TransferLearning.GraphBuilder(modelToFineTune)
            .setFeatureExtractor("layer1")
            .build();
    ComputationGraph clonedModel = modelNow.clone();
    // Check JSON configurations match
    assertEquals(clonedModel.getConfiguration().toJson(), modelNow.getConfiguration().toJson());
    // Check params match
    assertEquals(modelNow.params(), clonedModel.params());
    ComputationGraph notFrozen = new ComputationGraph(overallConf.graphBuilder()
            .addInputs("layer0In")
            .addLayer("layer0", new DenseLayer.Builder().nIn(2).nOut(3).build(), "layer0In")
            .addLayer("layer1", new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .activation(Activation.SOFTMAX).nIn(3).nOut(3).build(), "layer0")
            .setOutputs("layer1")
            .build());
    notFrozen.init();
    notFrozen.setParams(Nd4j.hstack(modelToFineTune.getLayer("layer2").params(),
            modelToFineTune.getLayer("layer3").params()));
    for (int i = 0; i < 5; i++) {
        notFrozen.fit(new DataSet(asFrozenFeatures, randomData.getLabels()));
        modelNow.fit(randomData);
        clonedModel.fit(randomData);
    }
    INDArray expectedParams = Nd4j.hstack(modelToFineTune.getLayer("layer0").params(),
            modelToFineTune.getLayer("layer1").params(), notFrozen.params());
    assertEquals(expectedParams, modelNow.params());
    assertEquals(expectedParams, clonedModel.params());
}
Also used: DataSet (org.nd4j.linalg.dataset.DataSet), NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration), DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer), INDArray (org.nd4j.linalg.api.ndarray.INDArray), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), Test (org.junit.Test)
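
In this test, setFeatureExtractor("layer1") freezes "layer1" and everything feeding into it, so only "layer2" and "layer3" continue to train; the assertions then confirm that a clone of the frozen graph trains identically. A minimal sketch of the surrounding transfer-learning API (FineTuneConfiguration reflects typical usage rather than this test, and the learning rate is an illustrative value):

// Hedged sketch: fine-tune an existing graph while freezing the feature extractor
FineTuneConfiguration fineTune = new FineTuneConfiguration.Builder()
        .learningRate(0.01) // illustrative value
        .build();
ComputationGraph transferred = new TransferLearning.GraphBuilder(modelToFineTune)
        .fineTuneConfiguration(fineTune)
        .setFeatureExtractor("layer1") // freeze "layer1" and everything before it
        .build();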

Example 3 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

From the class TestGraphNodes, method testDuplicateToTimeSeriesVertex. DuplicateToTimeSeriesVertex broadcasts a 2d activation array across every time step of a reference 3d (time series) input.

@Test
public void testDuplicateToTimeSeriesVertex() {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder().graphBuilder()
            .addInputs("in2d", "in3d")
            .addVertex("duplicateTS", new DuplicateToTimeSeriesVertex("in3d"), "in2d")
            .addLayer("out", new OutputLayer.Builder().nIn(1).nOut(1).build(), "duplicateTS")
            .setOutputs("out")
            .build();
    ComputationGraph graph = new ComputationGraph(conf);
    graph.init();
    INDArray in2d = Nd4j.rand(3, 5);
    INDArray in3d = Nd4j.rand(new int[] { 3, 2, 7 });
    graph.setInputs(in2d, in3d);
    // Expected forward output: in2d duplicated across each of the 7 time steps
    INDArray expOut = Nd4j.zeros(3, 5, 7);
    for (int i = 0; i < 7; i++) {
        expOut.put(new INDArrayIndex[] { NDArrayIndex.all(), NDArrayIndex.all(), NDArrayIndex.point(i) }, in2d);
    }
    GraphVertex gv = graph.getVertex("duplicateTS");
    gv.setInputs(in2d);
    INDArray outFwd = gv.doForward(true);
    assertEquals(expOut, outFwd);
    // Backward pass should sum the incoming epsilon over the time dimension
    INDArray expOutBackward = expOut.sum(2);
    gv.setEpsilon(expOut);
    INDArray outBwd = gv.doBackward(false).getSecond()[0];
    assertEquals(expOutBackward, outBwd);
    // Configuration should survive a JSON round-trip
    String json = conf.toJson();
    ComputationGraphConfiguration conf2 = ComputationGraphConfiguration.fromJson(json);
    assertEquals(conf, conf2);
}
Also used: GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex), DuplicateToTimeSeriesVertex (org.deeplearning4j.nn.conf.graph.rnn.DuplicateToTimeSeriesVertex), INDArray (org.nd4j.linalg.api.ndarray.INDArray), ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), Test (org.junit.Test)
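
For readers who prefer broadcasting to the index loop above, the expected forward output can also be expressed as a single expansion from [3, 5] to [3, 5, 7]; this is a hypothetical equivalent, not what the test does:

// Hypothetical equivalent of the expOut loop: copy in2d into each of the 7 time steps
INDArray expOut2 = in2d.reshape(3, 5, 1).broadcast(3, 5, 7);

The backward check follows from the same structure: since every time step receives an identical copy of in2d, the epsilon flowing back to the 2d input is the sum over the time dimension, which is exactly what expOut.sum(2) computes.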

Example 4 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

From the class RemoteFlowIterationListener, method buildModelInfo.

protected ModelInfo buildModelInfo(Model model) {
    ModelInfo modelInfo = new ModelInfo();
    if (model instanceof ComputationGraph) {
        ComputationGraph graph = (ComputationGraph) model;
        /*
         * We assume the graph starts at its inputs: every layer connected to an
         * input sits on row y1, every layer connected to y1 sits on row y2, etc.
         */
        List<String> inputs = graph.getConfiguration().getNetworkInputs();
        // now we need to add inputs as y0 nodes
        int x = 0;
        for (String input : inputs) {
            GraphVertex vertex = graph.getVertex(input);
            INDArray gInput = vertex.getInputs()[0];
            long tadLength = Shape.getTADLength(gInput.shape(), ArrayUtil.range(1, gInput.rank()));
            long numSamples = gInput.lengthLong() / tadLength;
            StringBuilder builder = new StringBuilder();
            builder.append("Vertex name: ").append(input).append("<br/>");
            builder.append("Model input").append("<br/>");
            builder.append("Input size: ").append(tadLength).append("<br/>");
            builder.append("Batch size: ").append(numSamples).append("<br/>");
            LayerInfo info = new LayerInfo();
            info.setId(0);
            info.setName(input);
            info.setY(0);
            info.setX(x);
            info.setLayerType(INPUT);
            info.setDescription(new Description());
            info.getDescription().setMainLine("Model input");
            info.getDescription().setText(builder.toString());
            modelInfo.addLayer(info);
            x++;
        }
        GraphVertex[] vertices = graph.getVertices();
        // Fill the grid left-to-right, top-to-bottom
        List<String> needle = new ArrayList<>();
        // We assume the max row count can't be higher than the total number of vertices
        for (int y = 1; y < vertices.length; y++) {
            if (needle.isEmpty())
                needle.addAll(inputs);
            // For each grid row, look for nodes connected to the previous layer
            List<LayerInfo> layersForGridY = flattenToY(modelInfo, vertices, needle, y);
            needle.clear();
            for (LayerInfo layerInfo : layersForGridY) {
                needle.add(layerInfo.getName());
            }
            if (needle.isEmpty())
                break;
        }
    } else if (model instanceof MultiLayerNetwork) {
        MultiLayerNetwork network = (MultiLayerNetwork) model;
        // manually adding input layer
        INDArray input = model.input();
        long tadLength = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
        long numSamples = input.lengthLong() / tadLength;
        StringBuilder builder = new StringBuilder();
        builder.append("Model input").append("<br/>");
        builder.append("Input size: ").append(tadLength).append("<br/>");
        builder.append("Batch size: ").append(numSamples).append("<br/>");
        LayerInfo info = new LayerInfo();
        info.setId(0);
        info.setName("Input");
        info.setY(0);
        info.setX(0);
        info.setLayerType(INPUT);
        info.setDescription(new Description());
        info.getDescription().setMainLine("Model input");
        info.getDescription().setText(builder.toString());
        info.addConnection(0, 1);
        modelInfo.addLayer(info);
        // entry 0 is reserved for inputs
        int y = 1;
        // for MLN x value is always 0
        final int x = 0;
        for (Layer layer : network.getLayers()) {
            LayerInfo layerInfo = getLayerInfo(layer, x, y, y);
            // For an MLN the connections are known in advance: each layer feeds the next row (y + 1)
            layerInfo.addConnection(x, y + 1);
            modelInfo.addLayer(layerInfo);
            y++;
        }
        LayerInfo layerInfo = modelInfo.getLayerInfoByCoords(x, y - 1);
        layerInfo.dropConnections();
    }
    // find layers without connections, and mark them as output layers
    for (LayerInfo layerInfo : modelInfo.getLayers()) {
        if (layerInfo.getConnections().size() == 0)
            layerInfo.setLayerType("OUTPUT");
    }
    // now we apply colors to distinct layer types
    AtomicInteger cnt = new AtomicInteger(0);
    for (String layerType : modelInfo.getLayerTypes()) {
        String curColor = colors.get(cnt.getAndIncrement());
        if (cnt.get() >= colors.size())
            cnt.set(0);
        for (LayerInfo layerInfo : modelInfo.getLayersByType(layerType)) {
            if (layerType.equals(INPUT)) {
                layerInfo.setColor("#99ff66");
            } else if (layerType.equals("OUTPUT")) {
                layerInfo.setColor("#e6e6e6");
            } else {
                layerInfo.setColor(curColor);
            }
        }
    }
    return modelInfo;
}
Also used: Layer (org.deeplearning4j.nn.api.Layer), SubsamplingLayer (org.deeplearning4j.nn.conf.layers.SubsamplingLayer), FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer), BaseOutputLayer (org.deeplearning4j.nn.conf.layers.BaseOutputLayer), GraphVertex (org.deeplearning4j.nn.graph.vertex.GraphVertex), INDArray (org.nd4j.linalg.api.ndarray.INDArray), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
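
The batch-size bookkeeping above rests on one identity: the TAD (tensor-along-dimension) length over all non-batch dimensions is the number of values per example, so dividing the array's total length by it recovers the mini-batch size. A minimal sketch, assuming a conventional NCHW MNIST batch (the shape is an illustrative assumption):

// Per-example size arithmetic, as used in buildModelInfo above
INDArray input = Nd4j.rand(new int[] { 64, 1, 28, 28 }); // hypothetical NCHW batch
// TAD length over dims 1..rank-1: 1 * 28 * 28 = 784 values per example
long tadLength = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
// Total values / values per example = 64 examples in the batch
long numSamples = input.lengthLong() / tadLength;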

Example 5 with ComputationGraph

Use of org.deeplearning4j.nn.graph.ComputationGraph in project deeplearning4j by deeplearning4j.

From the class RemoteFlowIterationListener, method buildModelState.

protected void buildModelState(Model model) {
    // first we update performance state
    long timeSpent = currTime - lastTime;
    float timeSec = timeSpent / 1000f;
    INDArray input = model.input();
    long tadLength = Shape.getTADLength(input.shape(), ArrayUtil.range(1, input.rank()));
    long numSamples = input.lengthLong() / tadLength;
    modelState.addPerformanceSamples(numSamples / timeSec);
    modelState.addPerformanceBatches(1 / timeSec);
    modelState.setIterationTime(timeSpent);
    // now model score
    modelState.addScore((float) model.score());
    modelState.setScore((float) model.score());
    modelState.setTrainingTime(parseTime(System.currentTimeMillis() - initTime));
    // and now update model params/gradients
    Map<String, Map> newGrad = new LinkedHashMap<>();
    Map<String, Map> newParams = new LinkedHashMap<>();
    Map<String, INDArray> params = model.paramTable();
    Layer[] layers = null;
    if (model instanceof MultiLayerNetwork) {
        layers = ((MultiLayerNetwork) model).getLayers();
    } else if (model instanceof ComputationGraph) {
        layers = ((ComputationGraph) model).getLayers();
    }
    List<Double> lrs = new ArrayList<>();
    if (layers != null) {
        for (Layer layer : layers) {
            lrs.add(layer.conf().getLayer().getLearningRate());
        }
        modelState.setLearningRates(lrs);
    }
    Map<Integer, LayerParams> layerParamsMap = new LinkedHashMap<>();
    for (Map.Entry<String, INDArray> entry : params.entrySet()) {
        String param = entry.getKey();
        // Only parameters keyed as "<layerIndex>_<paramName>" are tracked here
        if (!Character.isDigit(param.charAt(0)))
            continue;
        int layer = Integer.parseInt(param.replaceAll("\\_.*$", ""));
        String key = param.replaceAll("^.*?_", "").toLowerCase();
        if (!layerParamsMap.containsKey(layer))
            layerParamsMap.put(layer, new LayerParams());
        HistogramBin histogram = new HistogramBin.Builder(entry.getValue().dup())
                .setBinCount(14)
                .setRounding(6)
                .build();
        // TODO: something better would be nice to have here
        if (key.equalsIgnoreCase("w")) {
            layerParamsMap.get(layer).setW(histogram.getData());
        } else if (key.equalsIgnoreCase("rw")) {
            layerParamsMap.get(layer).setRW(histogram.getData());
        } else if (key.equalsIgnoreCase("rwf")) {
            layerParamsMap.get(layer).setRWF(histogram.getData());
        } else if (key.equalsIgnoreCase("b")) {
            layerParamsMap.get(layer).setB(histogram.getData());
        }
    }
    modelState.setLayerParams(layerParamsMap);
}
Also used: Layer (org.deeplearning4j.nn.api.Layer), SubsamplingLayer (org.deeplearning4j.nn.conf.layers.SubsamplingLayer), FeedForwardLayer (org.deeplearning4j.nn.conf.layers.FeedForwardLayer), BaseOutputLayer (org.deeplearning4j.nn.conf.layers.BaseOutputLayer), AtomicInteger (java.util.concurrent.atomic.AtomicInteger), INDArray (org.nd4j.linalg.api.ndarray.INDArray), HistogramBin (org.deeplearning4j.ui.weights.HistogramBin), ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph), MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)
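
The two regular expressions above split DL4J parameter keys of the form "<layerIndex>_<paramKey>". The key names come from DL4J's parameter initializers: "W" for input weights, "b" for biases, "RW" for recurrent weights, and "RWF" for the forward-direction recurrent weights of a bidirectional LSTM. A small standalone illustration of the decomposition:

String param = "2_RW"; // e.g. the recurrent weights of layer 2
int layer = Integer.parseInt(param.replaceAll("\\_.*$", "")); // -> 2
String key = param.replaceAll("^.*?_", "").toLowerCase(); // -> "rw"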

Aggregations

ComputationGraph (org.deeplearning4j.nn.graph.ComputationGraph): 109 uses
Test (org.junit.Test): 73 uses
ComputationGraphConfiguration (org.deeplearning4j.nn.conf.ComputationGraphConfiguration): 63 uses
INDArray (org.nd4j.linalg.api.ndarray.INDArray): 62 uses
NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration): 36 uses
DataSet (org.nd4j.linalg.dataset.DataSet): 25 uses
NormalDistribution (org.deeplearning4j.nn.conf.distribution.NormalDistribution): 22 uses
OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer): 21 uses
DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer): 19 uses
MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork): 19 uses
ScoreIterationListener (org.deeplearning4j.optimize.listeners.ScoreIterationListener): 17 uses
DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator): 17 uses
IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator): 14 uses
Layer (org.deeplearning4j.nn.api.Layer): 14 uses
Random (java.util.Random): 11 uses
InMemoryModelSaver (org.deeplearning4j.earlystopping.saver.InMemoryModelSaver): 10 uses
MaxEpochsTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxEpochsTerminationCondition): 10 uses
TrainingMaster (org.deeplearning4j.spark.api.TrainingMaster): 10 uses
MaxTimeIterationTerminationCondition (org.deeplearning4j.earlystopping.termination.MaxTimeIterationTerminationCondition): 9 uses
GridExecutioner (org.nd4j.linalg.api.ops.executioner.GridExecutioner): 9 uses