Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
Class TestEarlyStoppingCompGraph, method testEarlyStoppingIris.
/**
 * Runs early stopping on a single-output-layer {@code ComputationGraph} over Iris with a
 * five-epoch cap, then checks the reported epoch count, termination reason/details, and that
 * the best-model score matches a score computed directly from the returned model.
 */
@Test
public void testEarlyStoppingIris() {
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.SGD)
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
            .setOutputs("0")
            .pretrain(false)
            .backprop(true)
            .build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));

    // The same iterator feeds both training and the score calculator.
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();
    EarlyStoppingConfiguration<ComputationGraph> esConf =
            new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
                    .scoreCalculator(new DataSetLossCalculatorCG(irisIter, true))
                    .modelSaver(saver)
                    .build();

    IEarlyStoppingTrainer<ComputationGraph> trainer = new EarlyStoppingGraphTrainer(esConf, net, irisIter);
    EarlyStoppingResult<ComputationGraph> result = trainer.fit();
    System.out.println(result);

    // Training must run for exactly the configured five epochs and stop on the epoch condition.
    assertEquals(5, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    Map<Integer, Double> scoreVsEpoch = result.getScoreVsEpoch();
    assertEquals(5, scoreVsEpoch.size());
    String expDetails = esConf.getEpochTerminationConditions().get(0).toString();
    assertEquals(expDetails, result.getTerminationDetails());

    // Check that best score actually matches (returned model vs. manually calculated score)
    ComputationGraph bestNetwork = result.getBestModel();
    assertNotNull(bestNetwork);
    irisIter.reset();
    double score = bestNetwork.score(irisIter.next());
    assertEquals(result.getBestModelScore(), score, 1e-2);
}
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
Class TestEarlyStoppingCompGraph, method testBadTuning.
/**
 * Test poor tuning (high LR): training should terminate on
 * {@code MaxScoreIterationTerminationCondition} well before the epoch limit, and the best
 * model should be the one from epoch 0.
 */
@Test
public void testBadTuning() {
    Nd4j.getRandom().setSeed(12345);
    ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
            .seed(12345)
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.SGD)
            .learningRate(5.0) // Intentionally huge LR
            .weightInit(WeightInit.XAVIER)
            .graphBuilder()
            .addInputs("in")
            .addLayer("0", new OutputLayer.Builder().nIn(4).nOut(3)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build(), "in")
            .setOutputs("0")
            .pretrain(false)
            .backprop(true)
            .build();
    ComputationGraph net = new ComputationGraph(conf);
    net.setListeners(new ScoreIterationListener(1));

    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    EarlyStoppingModelSaver<ComputationGraph> saver = new InMemoryModelSaver<>();

    // Initial score is ~2.5. One instance is shared by the configuration and the
    // expected-details assertion so the threshold cannot drift between the two.
    MaxScoreIterationTerminationCondition maxScoreCondition = new MaxScoreIterationTerminationCondition(10);
    EarlyStoppingConfiguration<ComputationGraph> esConf =
            new EarlyStoppingConfiguration.Builder<ComputationGraph>()
                    .epochTerminationConditions(new MaxEpochsTerminationCondition(5000))
                    .iterationTerminationConditions(
                            new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES),
                            maxScoreCondition)
                    .scoreCalculator(new DataSetLossCalculatorCG(irisIter, true))
                    .modelSaver(saver)
                    .build();

    // Parameterized (not raw) types, matching the other early-stopping tests.
    IEarlyStoppingTrainer<ComputationGraph> trainer = new EarlyStoppingGraphTrainer(esConf, net, irisIter);
    EarlyStoppingResult<ComputationGraph> result = trainer.fit();

    assertTrue(result.getTotalEpochs() < 5);
    assertEquals(EarlyStoppingResult.TerminationReason.IterationTerminationCondition, result.getTerminationReason());
    String expDetails = maxScoreCondition.toString();
    assertEquals(expDetails, result.getTerminationDetails());
    assertEquals(0, result.getBestModelEpoch());
    assertNotNull(result.getBestModel());
}
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
Class EvalTest, method testIris.
/**
 * Fits a two-layer network on a train split of Iris and checks that {@code Evaluation}
 * reports identical F1 and accuracy whether or not the class count is supplied, and that
 * {@code model.evaluate(...)} agrees with the manually built evaluation.
 */
@Test
public void testIris() {
    // Network: 4 inputs -> 2 tanh units -> 3-way softmax output
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.LINE_GRADIENT_DESCENT)
            .iterations(1)
            .seed(42)
            .learningRate(1e-6)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(4).nOut(2)
                    .activation(Activation.TANH)
                    .weightInit(WeightInit.XAVIER)
                    .build())
            .layer(1, new org.deeplearning4j.nn.conf.layers.OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                    .nIn(2).nOut(3)
                    .weightInit(WeightInit.XAVIER)
                    .activation(Activation.SOFTMAX)
                    .build())
            .build();

    MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();
    model.setListeners(Arrays.asList((IterationListener) new ScoreIterationListener(1)));

    // Pull the data as one batch, shuffle it, and split it into train/test
    DataSetIterator irisIterator = new IrisDataSetIterator(150, 150);
    DataSet allExamples = irisIterator.next();
    allExamples.shuffle();
    SplitTestAndTrain split = allExamples.splitTestAndTrain(5, new Random(42));

    // Each split is normalized independently
    DataSet trainSet = split.getTrain();
    trainSet.normalizeZeroMeanZeroUnitVariance();
    DataSet testSet = split.getTest();
    testSet.normalizeZeroMeanZeroUnitVariance();

    INDArray testFeatures = testSet.getFeatureMatrix();
    INDArray expectedLabels = testSet.getLabels();

    model.fit(trainSet);
    INDArray predictedLabels = model.output(testFeatures);

    // Evaluation told the number of classes up front
    Evaluation eval = new Evaluation(3);
    eval.eval(expectedLabels, predictedLabels);
    double f1WithClassCount = eval.f1();
    double accuracyWithClassCount = eval.accuracy();

    // Evaluation constructed without a class count
    Evaluation evalNoClassCount = new Evaluation();
    evalNoClassCount.eval(expectedLabels, predictedLabels);
    double f1NoClassCount = evalNoClassCount.f1();
    double accuracyNoClassCount = evalNoClassCount.accuracy();

    // Both constructions must report exactly the same metrics (single batch)
    boolean metricsMatch = f1WithClassCount == f1NoClassCount
            && accuracyWithClassCount == accuracyNoClassCount;
    assertTrue(metricsMatch);

    Evaluation evalViaMethod = model.evaluate(new ListDataSetIterator(Collections.singletonList(testSet)));
    checkEvaluationEquality(eval, evalViaMethod);

    // Exercise the various confusion-matrix renderings
    System.out.println(eval.getConfusionMatrix().toString());
    System.out.println(eval.getConfusionMatrix().toCSV());
    System.out.println(eval.getConfusionMatrix().toHTML());
    System.out.println(eval.confusionToString());
}
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
Class EvaluationToolsTests, method testRocMultiToHtml.
/**
 * Trains a small classifier on standardized Iris data, evaluates it with
 * {@code ROCMultiClass}, and renders the ROC charts to HTML. The HTML is not inspected;
 * the test passes if rendering completes without an exception.
 */
@Test
public void testRocMultiToHtml() throws Exception {
    DataSetIterator irisIterator = new IrisDataSetIterator(150, 150);

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder()
                    .nIn(4).nOut(4)
                    .activation(Activation.TANH)
                    .build())
            .layer(1, new OutputLayer.Builder()
                    .nIn(4).nOut(3)
                    .activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT)
                    .build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    // Standardize the batch in place using statistics fitted on that same batch
    NormalizerStandardize standardizer = new NormalizerStandardize();
    DataSet iris = irisIterator.next();
    standardizer.fit(iris);
    standardizer.transform(iris);

    // 30 training passes over the same batch
    int passesLeft = 30;
    while (passesLeft > 0) {
        net.fit(iris);
        passesLeft--;
    }

    ROCMultiClass roc = new ROCMultiClass(20);
    irisIterator.reset();

    INDArray features = iris.getFeatures();
    INDArray labels = iris.getLabels();
    INDArray predictions = net.output(features);
    roc.eval(labels, predictions);

    String html = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica"));
    // System.out.println(html);
}
Use of org.nd4j.linalg.dataset.api.iterator.DataSetIterator in project deeplearning4j by deeplearning4j.
Class RecordReaderDataSetiteratorTest, method testCSVLoadingRegression.
/**
 * Writes a random CSV, reads it back through {@code RecordReaderDataSetIterator} in
 * regression mode with the label in column 0, and checks every minibatch's shapes, label
 * values, and feature values against the generated data.
 */
@Test
public void testCSVLoadingRegression() throws Exception {
    int lineCount = 30;
    int featureCount = 5;
    int batchSize = 10;
    int labelColumn = 0;

    String csvPath = FilenameUtils.concat(System.getProperty("java.io.tmpdir"), "rr_csv_test_rand.csv");
    double[][] expected = makeRandomCSV(csvPath, lineCount, featureCount);

    RecordReader csvReader = new CSVRecordReader();
    csvReader.initialize(new FileSplit(new File(csvPath)));
    DataSetIterator batches = new RecordReaderDataSetIterator(csvReader, null, batchSize, labelColumn, 1, true);

    int batchIndex = 0;
    while (batches.hasNext()) {
        DataSet batch = batches.next();
        INDArray features = batch.getFeatureMatrix();
        INDArray labels = batch.getLabels();
        assertArrayEquals(new int[] { batchSize, featureCount }, features.shape());
        assertArrayEquals(new int[] { batchSize, 1 }, labels.shape());

        int firstRow = batchIndex * batchSize;
        for (int row = 0; row < batchSize; row++) {
            double[] expectedRow = expected[firstRow + row];
            assertEquals(expectedRow[labelColumn], labels.getDouble(row), 1e-5f);

            // Each CSV row holds featureCount + 1 values; every column except the label
            // column maps, in order, to the next feature column.
            int featureColumn = 0;
            for (int col = 0; col <= featureCount; col++) {
                if (col != labelColumn) {
                    assertEquals(expectedRow[col], features.getDouble(row, featureColumn), 1e-5f);
                    featureColumn++;
                }
            }
        }
        batchIndex++;
    }

    // All lines must have been consumed as full minibatches
    assertEquals(lineCount / batchSize, batchIndex);
}
Aggregations