Usage of org.nd4j.linalg.dataset.api.DataSet in the deeplearning4j project (by deeplearning4j).
Class TestMiscFunctions, method testFeedForwardWithKeyGraph.
/**
 * Checks that {@code SparkComputationGraph.feedForwardWithKey} returns, for every key, the same
 * output as calling {@code outputSingle} on the local {@link ComputationGraph} for that key's inputs.
 *
 * <p>The Iris features are split into consecutive mini-batches of 1 to 5 examples. The sizes come
 * from a fixed-seed {@link Random}, so the split is reproducible. The same mini-batch is supplied
 * to both graph inputs ("in1" and "in2") and keyed by its index. The Spark results are then
 * compared key by key with the locally computed expectations.
 */
@Test
public void testFeedForwardWithKeyGraph() {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in1", "in2")
            .addLayer("0", new DenseLayer.Builder().nIn(4).nOut(3).build(), "in1")
            .addLayer("1", new DenseLayer.Builder().nIn(4).nOut(3).build(), "in2")
            // Both dense layers (nOut 3 each) feed the output layer, whose nIn is declared as 6.
            .addLayer("2", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(6).nOut(3)
                    .activation(Activation.SOFTMAX).build(), "0", "1")
            .setOutputs("2")
            .build();
    ComputationGraph net = new ComputationGraph(conf);
    net.init();

    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    DataSet ds = iter.next();
    // Bound the loop by the data actually loaded rather than repeating the literal 150.
    int numExamples = ds.numExamples();

    List<INDArray> expected = new ArrayList<>();
    List<Tuple2<Integer, INDArray[]>> mapFeatures = new ArrayList<>();
    int count = 0;
    int arrayCount = 0;
    // Fixed seed so the mini-batch split is identical on every run.
    Random r = new Random(12345);
    while (count < numExamples) {
        // 1 to 5 inclusive examples, clipped so the last mini-batch does not run past the end.
        int exampleCount = Math.min(r.nextInt(5) + 1, numExamples - count);
        INDArray subset = ds.getFeatures().get(NDArrayIndex.interval(count, count + exampleCount),
                NDArrayIndex.all());
        // The same features are fed to both graph inputs, locally and via Spark.
        expected.add(net.outputSingle(false, subset, subset));
        mapFeatures.add(new Tuple2<>(arrayCount, new INDArray[] {subset, subset}));
        arrayCount++;
        count += exampleCount;
    }

    JavaPairRDD<Integer, INDArray[]> rdd = sc.parallelizePairs(mapFeatures);
    SparkComputationGraph graph = new SparkComputationGraph(sc, net, null);
    // NOTE(review): 16 is presumably the inference batch size used by feedForwardWithKey; confirm.
    Map<Integer, INDArray[]> map = graph.feedForwardWithKey(rdd, 16).collectAsMap();

    // Every key must come back; without this check a missing key would surface as a
    // NullPointerException at map.get(i) below rather than as a clear assertion failure.
    assertEquals(expected.size(), map.size());
    for (int i = 0; i < expected.size(); i++) {
        INDArray exp = expected.get(i);
        INDArray act = map.get(i)[0];
        assertEquals(exp, act);
    }
}
Aggregations