Use of org.nd4j.linalg.dataset.api.DataSet in project deeplearning4j by deeplearning4j.
The class EvaluationToolsTests, method testRocHtml.
@Test
public void testRocHtml() throws Exception {
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                    .layer(1, new OutputLayer.Builder().nIn(4).nOut(2).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);
    // Collapse the three Iris classes into two so the binary ROC class can be used:
    // column 0 = setosa + versicolor, column 1 = virginica
    INDArray newLabels = Nd4j.create(150, 2);
    newLabels.getColumn(0).assign(ds.getLabels().getColumn(0));
    newLabels.getColumn(0).addi(ds.getLabels().getColumn(1));
    newLabels.getColumn(1).assign(ds.getLabels().getColumn(2));
    ds.setLabels(newLabels);
    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }
    ROC roc = new ROC(20);
    iter.reset();
    INDArray f = ds.getFeatures();
    INDArray l = ds.getLabels();
    INDArray out = net.output(f);
    roc.eval(l, out);
    String str = EvaluationTools.rocChartToHtml(roc);
    // System.out.println(str);
}
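The test only builds the HTML string in memory. A minimal sketch of persisting it so the chart can be opened in a browser; the file name "rocChart.html" is an arbitrary choice for illustration, not part of the original test:

import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Paths;

// Sketch: write the generated ROC chart HTML to disk (hypothetical file name).
String html = EvaluationTools.rocChartToHtml(roc);
Files.write(Paths.get("rocChart.html"), html.getBytes(StandardCharsets.UTF_8));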
Use of org.nd4j.linalg.dataset.api.DataSet in project deeplearning4j by deeplearning4j.
The class EvaluationToolsTests, method testRocMultiToHtml.
@Test
public void testRocMultiToHtml() throws Exception {
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                    .layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    NormalizerStandardize ns = new NormalizerStandardize();
    DataSet ds = iter.next();
    ns.fit(ds);
    ns.transform(ds);
    for (int i = 0; i < 30; i++) {
        net.fit(ds);
    }
    ROCMultiClass roc = new ROCMultiClass(20);
    iter.reset();
    INDArray f = ds.getFeatures();
    INDArray l = ds.getLabels();
    INDArray out = net.output(f);
    roc.eval(l, out);
    String str = EvaluationTools.rocChartToHtml(roc, Arrays.asList("setosa", "versicolor", "virginica"));
    // System.out.println(str);
}
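Besides rendering the chart, ROCMultiClass can also report a scalar one-vs-all AUC per class. A short sketch, assuming calculateAUC(int classIndex) is available on ROCMultiClass in this version of deeplearning4j:

// Sketch: print one-vs-all AUC for each Iris class (0 = setosa, 1 = versicolor, 2 = virginica).
// Assumes ROCMultiClass.calculateAUC(int classIndex) exists in this DL4J version.
for (int classIdx = 0; classIdx < 3; classIdx++) {
    System.out.println("AUC for class " + classIdx + ": " + roc.calculateAUC(classIdx));
}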
Use of org.nd4j.linalg.dataset.api.DataSet in project deeplearning4j by deeplearning4j.
The class TestRecordReaders, method testClassIndexOutsideOfRangeRRMDSI.
@Test
public void testClassIndexOutsideOfRangeRRMDSI() {
    Collection<Collection<Collection<Writable>>> c = new ArrayList<>();
    Collection<Collection<Writable>> seq1 = new ArrayList<>();
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    seq1.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(1)));
    c.add(seq1);
    Collection<Collection<Writable>> seq2 = new ArrayList<>();
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    // Class index 2 is invalid: the iterator below is configured with only 2 possible labels (0 and 1)
    seq2.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
    c.add(seq2);
    CollectionSequenceRecordReader csrr = new CollectionSequenceRecordReader(c);
    DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(csrr, 2, 2, 1);
    try {
        DataSet ds = dsi.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        System.out.println("testClassIndexOutsideOfRangeRRMDSI(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail();
    }
}
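For contrast, a sketch of the well-formed counterpart: when every class index is strictly less than the numPossibleLabels argument, the same iterator one-hot encodes the labels and next() returns a DataSet without throwing. The variable names below are illustrative only.

// Sketch: all class indices are 0 or 1, within range for numPossibleLabels = 2.
Collection<Collection<Collection<Writable>>> validSeqs = new ArrayList<>();
Collection<Collection<Writable>> seq = new ArrayList<>();
seq.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
seq.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(1)));
validSeqs.add(seq);
DataSetIterator validIter = new SequenceRecordReaderDataSetIterator(
                new CollectionSequenceRecordReader(validSeqs), 1, 2, 1);
DataSet validDs = validIter.next();   // succeeds: labels are one-hot encoded over 2 classes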
Use of org.nd4j.linalg.dataset.api.DataSet in project deeplearning4j by deeplearning4j.
The class ComputationGraph, method fit.
/**
 * Fit the ComputationGraph using a DataSetIterator.
 * Note that this method can only be used with ComputationGraphs that have a single input and a single output.
 */
public void fit(DataSetIterator iterator) {
    if (flattenedGradients == null)
        initGradientsView();
    if (numInputArrays != 1 || numOutputArrays != 1)
        throw new UnsupportedOperationException(
                        "Cannot train ComputationGraph network with multiple inputs or outputs using a DataSetIterator");
    DataSetIterator dataSetIterator;
    // we're wrapping all iterators into AsyncDataSetIterator to provide background prefetch - where appropriate
    if (iterator.asyncSupported()) {
        dataSetIterator = new AsyncDataSetIterator(iterator, 2);
    } else {
        dataSetIterator = iterator;
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochStart(this);
        }
    }
    if (configuration.isPretrain()) {
        pretrain(dataSetIterator);
    }
    if (configuration.isBackprop()) {
        update(TaskUtils.buildTask(dataSetIterator));
        while (dataSetIterator.hasNext()) {
            DataSet next = dataSetIterator.next();
            if (next.getFeatures() == null || next.getLabels() == null)
                break;
            boolean hasMaskArrays = next.hasMaskArrays();
            if (hasMaskArrays) {
                INDArray[] fMask = (next.getFeaturesMaskArray() != null
                                ? new INDArray[] {next.getFeaturesMaskArray()} : null);
                INDArray[] lMask = (next.getLabelsMaskArray() != null
                                ? new INDArray[] {next.getLabelsMaskArray()} : null);
                setLayerMaskArrays(fMask, lMask);
            }
            if (configuration.getBackpropType() == BackpropType.TruncatedBPTT) {
                doTruncatedBPTT(new INDArray[] {next.getFeatures()}, new INDArray[] {next.getLabels()},
                                (hasMaskArrays ? new INDArray[] {next.getFeaturesMaskArray()} : null),
                                (hasMaskArrays ? new INDArray[] {next.getLabelsMaskArray()} : null));
            } else {
                setInput(0, next.getFeatures());
                setLabel(0, next.getLabels());
                if (solver == null) {
                    //TODO: don't like this
                    solver = new Solver.Builder().configure(defaultConfiguration).listeners(listeners)
                                    .model(this).build();
                }
                solver.optimize();
            }
            if (hasMaskArrays) {
                clearLayerMaskArrays();
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
        }
    }
    if (trainingListeners.size() > 0) {
        for (TrainingListener tl : trainingListeners) {
            tl.onEpochEnd(this);
        }
    }
}
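Since the doc comment restricts this overload to graphs with a single input and a single output, a minimal call-site sketch is shown below; the graph configuration is an assumption chosen only to illustrate the one-input, one-output case.

// Sketch: a single-input, single-output graph, the only shape fit(DataSetIterator) supports.
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
                .weightInit(WeightInit.XAVIER)
                .graphBuilder()
                .addInputs("in")
                .addLayer("out", new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT)
                                .nIn(4).nOut(3).activation(Activation.SOFTMAX).build(), "in")
                .setOutputs("out")
                .build();
ComputationGraph graph = new ComputationGraph(conf);
graph.init();
graph.fit(new IrisDataSetIterator(50, 150));   // one feature array and one label array per DataSet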
Use of org.nd4j.linalg.dataset.api.DataSet in project deeplearning4j by deeplearning4j.
The class TestMiscFunctions, method testFeedForwardWithKey.
@Test
public void testFeedForwardWithKey() {
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder().weightInit(WeightInit.XAVIER).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(3).build())
                    .layer(1, new OutputLayer.Builder(LossFunctions.LossFunction.MCXENT).nIn(3).nOut(3)
                                    .activation(Activation.SOFTMAX).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    DataSetIterator iter = new IrisDataSetIterator(150, 150);
    DataSet ds = iter.next();
    List<INDArray> expected = new ArrayList<>();
    List<Tuple2<Integer, INDArray>> mapFeatures = new ArrayList<>();
    int count = 0;
    int arrayCount = 0;
    Random r = new Random(12345);
    while (count < 150) {
        //1 to 5 inclusive examples
        int exampleCount = r.nextInt(5) + 1;
        if (count + exampleCount > 150)
            exampleCount = 150 - count;
        INDArray subset = ds.getFeatures().get(NDArrayIndex.interval(count, count + exampleCount),
                        NDArrayIndex.all());
        expected.add(net.output(subset, false));
        mapFeatures.add(new Tuple2<>(arrayCount, subset));
        arrayCount++;
        count += exampleCount;
    }
    JavaPairRDD<Integer, INDArray> rdd = sc.parallelizePairs(mapFeatures);
    SparkDl4jMultiLayer multiLayer = new SparkDl4jMultiLayer(sc, net, null);
    Map<Integer, INDArray> map = multiLayer.feedForwardWithKey(rdd, 16).collectAsMap();
    for (int i = 0; i < expected.size(); i++) {
        INDArray exp = expected.get(i);
        INDArray act = map.get(i);
        assertEquals(exp, act);
    }
}