Usage of org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator in the deeplearning4j/deeplearning4j project.
Example from class DataSetIteratorTest, method testBatchSizeOfOneIris.
@Test
public void testBatchSizeOfOneIris() throws Exception {
    // Checks that an Iris iterator with batch size 1:
    //  (a) returns the requested number of examples, exactly one example per DataSet, and
    //  (b) yields labels that are proper one-hot vectors (i.e., each label row sums to 1.0).
    final int numExamples = 5;
    DataSetIterator iris = new IrisDataSetIterator(1, numExamples);
    int irisC = 0;
    while (iris.hasNext()) {
        irisC++;
        DataSet ds = iris.next();
        // Batch size is 1, so every DataSet must hold exactly one example.
        assertEquals(1, ds.numExamples());
        // Sum over all dimensions of the single label row; compare with a tolerance instead of
        // exact floating-point equality so a failure reports the actual value.
        assertEquals(1.0, ds.getLabels().sum(Integer.MAX_VALUE).getDouble(0), 1e-6);
    }
    assertEquals(numExamples, irisC);
}
Usage of org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator in the deeplearning4j/deeplearning4j project.
Example from class ROCTest, method RocEvalSanityCheck.
@Test
public void RocEvalSanityCheck() {
    // Sanity check: the ROC produced by MultiLayerNetwork.evaluateROCMultiClass over an iterator must
    // match a ROCMultiClass computed manually from the same normalized data and network output.
    final int nClasses = 3;
    final int rocThresholdSteps = 32;

    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
            .layer(1, new OutputLayer.Builder().nIn(4).nOut(nClasses).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    // Fit the normalizer on the single full-batch DataSet and apply it in place. The same normalizer
    // is attached to the iterator so the evaluation path sees identically normalized features.
    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);
    iter.setPreProcessor(ns);

    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }

    // iter was consumed by iter.next() above; rewind it explicitly rather than relying on the
    // evaluation method to do so.
    iter.reset();
    ROCMultiClass roc = net.evaluateROCMultiClass(iter, rocThresholdSteps);

    INDArray f = ds.getFeatures();
    INDArray l = ds.getLabels();
    INDArray out = net.output(f);
    ROCMultiClass manual = new ROCMultiClass(rocThresholdSteps);
    manual.eval(l, out);

    // The manually computed ROC is the expected value in every comparison.
    for (int i = 0; i < nClasses; i++) {
        assertEquals(manual.calculateAUC(i), roc.calculateAUC(i), 1e-6);
        double[][] rocCurve = roc.getResultsAsArray(i);
        double[][] rocManual = manual.getResultsAsArray(i);
        assertArrayEquals(rocManual[0], rocCurve[0], 1e-6);
        assertArrayEquals(rocManual[1], rocCurve[1], 1e-6);
    }
}
Usage of org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator in the deeplearning4j/deeplearning4j project.
Example from class EvaluationToolsTests, method testRocHtml.
@Test
public void testRocHtml() throws Exception {
    // Smoke test: train a small binary classifier on Iris (classes 0 and 1 merged into one label),
    // evaluate it with ROC, and check that EvaluationTools renders the ROC chart to non-empty HTML.
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
            .layer(1, new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();

    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);

    // Collapse the 3 Iris label columns into 2: column 0 = column 0 + column 1, column 1 = column 2.
    INDArray oldLabels = ds.getLabels();
    INDArray newLabels = Nd4j.create(ds.numExamples(), 2);
    newLabels.getColumn(0).assign(oldLabels.getColumn(0));
    newLabels.getColumn(0).addi(oldLabels.getColumn(1));
    newLabels.getColumn(1).assign(oldLabels.getColumn(2));
    ds.setLabels(newLabels);

    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }

    ROC roc = new ROC(20);
    INDArray f = ds.getFeatures();
    INDArray l = ds.getLabels();
    INDArray out = net.output(f);
    roc.eval(l, out);

    String str = EvaluationTools.rocChartToHtml(roc);
    // Without these checks the test would only verify that rendering does not throw.
    org.junit.Assert.assertNotNull(str);
    org.junit.Assert.assertFalse(str.isEmpty());
}
Usage of org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator in the deeplearning4j/deeplearning4j project.
Example from class TestEarlyStoppingSpark, method testEarlyStoppingIris.
@Test
public void testEarlyStoppingIris() {
    // Trains a single-layer Iris classifier with Spark early stopping, capped at maxEpochs epochs and a
    // 1-minute iteration time limit. Checks that training stops on the epoch condition and that the
    // reported best-model score matches a manually computed score.
    final int maxEpochs = 5;

    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
            .optimizationAlgo(OptimizationAlgorithm.STOCHASTIC_GRADIENT_DESCENT)
            .iterations(1)
            .updater(Updater.SGD)
            .weightInit(WeightInit.XAVIER)
            .list()
            .layer(0, new OutputLayer.Builder().nIn(4).nOut(3)
                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
            .pretrain(false)
            .backprop(true)
            .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.setListeners(new ScoreIterationListener(1));

    JavaRDD<DataSet> irisData = getIris();
    EarlyStoppingModelSaver<MultiLayerNetwork> saver = new InMemoryModelSaver<>();
    // Obtain the Spark context through getContext() everywhere, the same way the trainer below does.
    EarlyStoppingConfiguration<MultiLayerNetwork> esConf = new EarlyStoppingConfiguration.Builder<MultiLayerNetwork>()
            .epochTerminationConditions(new MaxEpochsTerminationCondition(maxEpochs))
            .iterationTerminationConditions(new MaxTimeIterationTerminationCondition(1, TimeUnit.MINUTES))
            .scoreCalculator(new SparkDataSetLossCalculator(irisData, true, getContext().sc()))
            .modelSaver(saver)
            .build();
    IEarlyStoppingTrainer<MultiLayerNetwork> trainer = new SparkEarlyStoppingTrainer(getContext().sc(),
            new ParameterAveragingTrainingMaster.Builder(irisBatchSize()).saveUpdater(true)
                    .averagingFrequency(1).build(),
            esConf, net, irisData);

    EarlyStoppingResult<MultiLayerNetwork> result = trainer.fit();
    System.out.println(result);

    assertEquals(maxEpochs, result.getTotalEpochs());
    assertEquals(EarlyStoppingResult.TerminationReason.EpochTerminationCondition, result.getTerminationReason());
    Map<Integer, Double> scoreVsEpoch = result.getScoreVsEpoch();
    assertEquals(maxEpochs, scoreVsEpoch.size());
    String expDetails = esConf.getEpochTerminationConditions().get(0).toString();
    assertEquals(expDetails, result.getTerminationDetails());

    MultiLayerNetwork bestNetwork = result.getBestModel();
    assertNotNull(bestNetwork);

    // Check that the best score actually matches (returned model vs. manually calculated score).
    double score = bestNetwork.score(new IrisDataSetIterator(150, 150).next());
    double bestModelScore = result.getBestModelScore();
    assertEquals(bestModelScore, score, 1e-3);
}
Usage of org.deeplearning4j.datasets.iterator.impl.IrisDataSetIterator in the deeplearning4j/deeplearning4j project.
Example from class TestEarlyStoppingSpark, method getIris.
private JavaRDD<DataSet> getIris() {
    // Collects the 150-example Iris data set as DataSet mini-batches of size irisBatchSize(),
    // then parallelizes them into an RDD on the context returned by getContext().
    final int irisExamples = 150;
    JavaSparkContext context = getContext();
    List<DataSet> batches = new ArrayList<>(irisExamples);
    for (IrisDataSetIterator it = new IrisDataSetIterator(irisBatchSize(), irisExamples); it.hasNext(); ) {
        batches.add(it.next());
    }
    return context.parallelize(batches);
}
Aggregations