use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
the class SamplingDataSetIterator method next.
/**
 * Returns the next batch by sampling {@code batchSize} examples from {@code sampleFrom}.
 * After a successful sample, {@code numTimesSampled} is advanced by {@code batchSize}.
 *
 * @return the data set produced by {@code sampleFrom.sample(batchSize)}
 */
@Override
public DataSet next() {
    final DataSet sampledBatch = sampleFrom.sample(batchSize);
    // Count only after sampling succeeded, so a failed sample leaves the counter untouched.
    numTimesSampled += batchSize;
    return sampledBatch;
}
use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
the class DataSetLossCalculator method calculateScore.
/**
 * Scores the given network over every data set supplied by {@code dataSetIterator}.
 *
 * <p>The iterator is reset first. Each batch's score is weighted by the size of
 * dimension 0 of its feature matrix and accumulated. Iteration stops early if the
 * iterator hands back {@code null}.
 *
 * @param network the network to score
 * @return the accumulated weighted score, divided by the total weight when
 *         {@code average} is set, otherwise the raw sum
 */
@Override
public double calculateScore(MultiLayerNetwork network) {
    double totalLoss = 0.0;
    int totalExamples = 0;

    dataSetIterator.reset();
    while (dataSetIterator.hasNext()) {
        final DataSet batch = dataSetIterator.next();
        if (batch == null) {
            break;
        }
        final int batchExamples = batch.getFeatureMatrix().size(0);
        totalLoss += network.score(batch) * batchExamples;
        totalExamples += batchExamples;
    }

    return average ? totalLoss / totalExamples : totalLoss;
}
use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
the class BaseDatasetIterator method next.
/**
 * Asks the fetcher for {@code num} examples and returns the data set it then provides.
 * When a pre-processor is set, it is applied to the data set before returning.
 *
 * @param num the number of examples passed to {@code fetcher.fetch}
 * @return the data set obtained from {@code fetcher.next()}
 */
@Override
public DataSet next(int num) {
    fetcher.fetch(num);
    final DataSet batch = fetcher.next();
    if (preProcessor == null) {
        return batch;
    }
    preProcessor.preProcess(batch);
    return batch;
}
use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
the class ExistingDataSetIterator method next.
/**
 * Returns the next data set from the wrapped iterator.
 *
 * <p>If a pre-processor is set and the data set is not yet flagged as pre-processed,
 * the pre-processor is applied and the data set is marked, so it is not processed
 * again on a later pass.
 *
 * @return the next data set from {@code iterator}
 */
@Override
public DataSet next() {
    final DataSet current = iterator.next();
    final boolean needsPreProcessing = preProcessor != null && !current.isPreProcessed();
    if (needsPreProcessing) {
        preProcessor.preProcess(current);
        current.markAsPreProcessed();
    }
    return current;
}
use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
the class DropoutLayerTest method testDropoutLayerWithConvMnist.
/**
 * Verifies that configuring dropout directly on the output layer is equivalent to
 * inserting a separate {@code DropoutLayer} in front of an output layer without dropout.
 *
 * <p>Both networks share the same seed and convolution layer and are fitted once on the
 * same MNIST batch. The test then asserts that the corresponding parameters are equal,
 * and that the corresponding activations are equal in training mode and in test mode.
 */
@Test
public void testDropoutLayerWithConvMnist() throws Exception {
    DataSetIterator iter = new MnistDataSetIterator(2, 2);
    DataSet next = iter.next();

    // Run with dropout configured on the output layer itself (no separate dropout layer)
    MultiLayerConfiguration confIntegrated = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .seed(123)
            .list()
            .layer(0, new ConvolutionLayer.Builder(4, 4)
                    .stride(2, 2)
                    .nIn(1)
                    .nOut(20)
                    .activation(Activation.RELU)
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.SOFTMAX)
                    .dropOut(0.25)
                    .nOut(10)
                    .build())
            .backprop(true)
            .pretrain(false)
            .setInputType(InputType.convolutionalFlat(28, 28, 1))
            .build();
    MultiLayerNetwork netIntegrated = new MultiLayerNetwork(confIntegrated);
    netIntegrated.init();
    netIntegrated.fit(next);

    // Run with a separate dropout layer placed before the output layer
    MultiLayerConfiguration confSeparate = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .seed(123)
            .list()
            .layer(0, new ConvolutionLayer.Builder(4, 4)
                    .stride(2, 2)
                    .nIn(1)
                    .nOut(20)
                    .activation(Activation.RELU)
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .layer(1, new DropoutLayer.Builder(0.25).build())
            .layer(2, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.SOFTMAX)
                    .nOut(10)
                    .build())
            .backprop(true)
            .pretrain(false)
            .setInputType(InputType.convolutionalFlat(28, 28, 1))
            .build();
    MultiLayerNetwork netSeparate = new MultiLayerNetwork(confSeparate);
    netSeparate.init();
    netSeparate.fit(next);

    // check parameters: the output layer is layer 1 in the integrated net and layer 2 in the separate net
    assertEquals(netIntegrated.getLayer(0).getParam("W"), netSeparate.getLayer(0).getParam("W"));
    assertEquals(netIntegrated.getLayer(0).getParam("b"), netSeparate.getLayer(0).getParam("b"));
    assertEquals(netIntegrated.getLayer(1).getParam("W"), netSeparate.getLayer(2).getParam("W"));
    assertEquals(netIntegrated.getLayer(1).getParam("b"), netSeparate.getLayer(2).getParam("b"));

    // check activations
    netIntegrated.setInput(next.getFeatureMatrix());
    netSeparate.setInput(next.getFeatureMatrix());

    // Training mode: reseed before each pass so both networks draw from the same random sequence.
    Nd4j.getRandom().setSeed(12345);
    List<INDArray> actTrainIntegrated = netIntegrated.feedForward(true);
    Nd4j.getRandom().setSeed(12345);
    List<INDArray> actTrainSeparate = netSeparate.feedForward(true);
    // The separate net has one extra layer, so its final activation sits one index later.
    assertEquals(actTrainIntegrated.get(1), actTrainSeparate.get(1));
    assertEquals(actTrainIntegrated.get(2), actTrainSeparate.get(3));

    // Test mode: compare test-mode activations against test-mode activations only.
    Nd4j.getRandom().setSeed(12345);
    List<INDArray> actTestIntegrated = netIntegrated.feedForward(false);
    Nd4j.getRandom().setSeed(12345);
    List<INDArray> actTestSeparate = netSeparate.feedForward(false);
    assertEquals(actTestIntegrated.get(1), actTestSeparate.get(1));
    assertEquals(actTestIntegrated.get(2), actTestSeparate.get(3));
}
Aggregations