Search in sources:

Example 11 with DataSet

Use of `org.nd4j.linalg.dataset.api.DataSet` in project deeplearning4j (by deeplearning4j).

From class TestMiscFunctions, method testFeedForwardWithKeyGraph.

/**
 * Checks that {@code SparkComputationGraph.feedForwardWithKey} on a two-input graph returns, for every key,
 * the same output as a local {@code ComputationGraph.outputSingle} call on the same inputs.
 * <p>
 * The Iris features are split into randomly sized minibatches of 1 to 5 examples. Each minibatch is fed to
 * both graph inputs and keyed by its position in the sequence.
 */
@Test
public void testFeedForwardWithKeyGraph() {
    final int numExamples = 150;

    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in1", "in2")
            .addLayer("0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "in1")
            .addLayer("1", new DenseLayer.Builder().nIn(4).nOut(3).build(), "in2")
            // Layer "2" takes both 3-unit dense outputs as input, so nIn = 3 + 3
            .addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(6).nOut(3)
                    .activation(Activation.SOFTMAX).build(), "0", "1")
            .setOutputs("2")
            .build();

    ComputationGraph net = new ComputationGraph(conf);
    net.init();

    DataSetIterator iter = new IrisDataSetIterator(numExamples, numExamples);
    DataSet ds = iter.next();

    List<INDArray> expected = new ArrayList<>();
    List<Tuple2<Integer, INDArray[]>> mapFeatures = new ArrayList<>();
    Random r = new Random(12345);
    int count = 0;
    while (count < numExamples) {
        // 1 to 5 inclusive examples, clipped so the last minibatch does not run past the end
        int exampleCount = Math.min(r.nextInt(5) + 1, numExamples - count);
        INDArray subset = ds.getFeatures().get(NDArrayIndex.interval(count, count + exampleCount), NDArrayIndex.all());
        // Key each minibatch by its index in 'expected' so results can be matched after collectAsMap()
        int key = expected.size();
        expected.add(net.outputSingle(false, subset, subset));
        mapFeatures.add(new Tuple2<>(key, new INDArray[] {subset, subset}));
        count += exampleCount;
    }

    JavaPairRDD<Integer, INDArray[]> rdd = sc.parallelizePairs(mapFeatures);
    SparkComputationGraph graph = new SparkComputationGraph(sc, net, null);
    Map<Integer, INDArray[]> map = graph.feedForwardWithKey(rdd, 16).collectAsMap();

    // Fail with a clear assertion message, rather than an NPE in the loop, if any keyed output is missing
    assertEquals("Unexpected number of keyed outputs", expected.size(), map.size());
    for (int i = 0; i < expected.size(); i++) {
        INDArray[] act = map.get(i);
        org.junit.Assert.assertNotNull("No output returned for key " + i, act);
        assertEquals(expected.get(i), act[0]);
    }
}
Also used: OutputLayer(org.deeplearning4j.nn.conf.layers.OutputLayer) SparkComputationGraph(org.deeplearning4j.spark.impl.graph.SparkComputationGraph) IrisDataSetIterator(org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator) DataSet(org.nd4j.linalg.dataset.api.DataSet) DenseLayer(org.deeplearning4j.nn.conf.layers.DenseLayer) INDArray(org.nd4j.linalg.api.ndarray.INDArray) Tuple2(scala.Tuple2) ComputationGraphConfiguration(org.deeplearning4j.nn.conf.ComputationGraphConfiguration) ComputationGraph(org.deeplearning4j.nn.graph.ComputationGraph) DataSetIterator(org.nd4j.linalg.dataset.api.iterator.DataSetIterator) BaseSparkTest(org.deeplearning4j.spark.BaseSparkTest) Test(org.junit.Test)

Aggregations

DataSet (org.nd4j.linalg.dataset.api.DataSet)11 Test (org.junit.Test)9 DataSetIterator (org.nd4j.linalg.dataset.api.iterator.DataSetIterator)9 INDArray (org.nd4j.linalg.api.ndarray.INDArray)8 IrisDataSetIterator (org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator)5 MultiLayerNetwork (org.deeplearning4j.nn.multilayer.MultiLayerNetwork)5 ArrayList (java.util.ArrayList)4 MultiLayerConfiguration (org.deeplearning4j.nn.conf.MultiLayerConfiguration)4 NeuralNetConfiguration (org.deeplearning4j.nn.conf.NeuralNetConfiguration)4 DenseLayer (org.deeplearning4j.nn.conf.layers.DenseLayer)4 OutputLayer (org.deeplearning4j.nn.conf.layers.OutputLayer)4 Collection (java.util.Collection)3 DoubleWritable (org.datavec.api.writable.DoubleWritable)3 IntWritable (org.datavec.api.writable.IntWritable)3 RecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.RecordReaderDataSetIterator)3 SequenceRecordReaderDataSetIterator (org.deeplearning4j.datasets.datavec.SequenceRecordReaderDataSetIterator)3 DL4JException (org.deeplearning4j.exception.DL4JException)3 NormalizerStandardize (org.nd4j.linalg.dataset.api.preprocessor.NormalizerStandardize)3 CollectionSequenceRecordReader (org.datavec.api.records.reader.impl.collection.CollectionSequenceRecordReader)2 Writable (org.datavec.api.writable.Writable)2