Example usage of org.nd4j.linalg.dataset.api.DataSet in the deeplearning4j project, taken from the class TestCnnSentenceDataSetIterator, method testSentenceIterator.
@Test
public void testSentenceIterator() throws Exception {
    // Load a small pre-trained word2vec model shipped as a test resource
    WordVectors wordVectors = WordVectorSerializer
                    .readWord2VecModel(new ClassPathResource("word2vec/googleload/sample_vec.bin").getFile());
    int vecSize = wordVectors.lookupTable().layerSize();
    List<String> sentences = new ArrayList<>();
    // First sentence: every word is present in the vocabulary
    sentences.add("these balance Database model");
    // Second sentence: one word is unknown and will be skipped
    sentences.add("into same THISWORDDOESNTEXIST are");
    int maxLen = 4;
    List<String> tokensFirst = Arrays.asList("these", "balance", "Database", "model");
    List<String> tokensSecond = Arrays.asList("into", "same", "are");
    List<String> sentenceLabels = Arrays.asList("Positive", "Negative");
    // Labels are ordered alphabetically, so Positive maps to [0,1] and Negative to [1,0]
    INDArray expLabels = Nd4j.create(new double[][] {{0, 1}, {1, 0}});
    // Exercise both feature layouts: words along the height axis and along the width axis
    for (boolean alongHeight : new boolean[] {true, false}) {
        INDArray expFeatures = alongHeight ? Nd4j.create(2, 1, maxLen, vecSize)
                        : Nd4j.create(2, 1, vecSize, maxLen);
        // Second sentence has only 3 known words, so its last time step is masked out
        INDArray expFeatureMask = Nd4j.create(new double[][] {{1, 1, 1, 1}, {1, 1, 1, 0}});
        // Fill expected features for the first example (minibatch index 0)
        for (int t = 0; t < tokensFirst.size(); t++) {
            INDArray dest = alongHeight
                            ? expFeatures.get(NDArrayIndex.point(0), NDArrayIndex.point(0),
                                            NDArrayIndex.point(t), NDArrayIndex.all())
                            : expFeatures.get(NDArrayIndex.point(0), NDArrayIndex.point(0),
                                            NDArrayIndex.all(), NDArrayIndex.point(t));
            dest.assign(wordVectors.getWordVectorMatrix(tokensFirst.get(t)));
        }
        // Fill expected features for the second example (minibatch index 1)
        for (int t = 0; t < tokensSecond.size(); t++) {
            INDArray dest = alongHeight
                            ? expFeatures.get(NDArrayIndex.point(1), NDArrayIndex.point(0),
                                            NDArrayIndex.point(t), NDArrayIndex.all())
                            : expFeatures.get(NDArrayIndex.point(1), NDArrayIndex.point(0),
                                            NDArrayIndex.all(), NDArrayIndex.point(t));
            dest.assign(wordVectors.getWordVectorMatrix(tokensSecond.get(t)));
        }
        LabeledSentenceProvider provider =
                        new CollectionLabeledSentenceProvider(sentences, sentenceLabels, null);
        CnnSentenceDataSetIterator dsi = new CnnSentenceDataSetIterator.Builder().sentenceProvider(provider)
                        .wordVectors(wordVectors).maxSentenceLength(256).minibatchSize(32)
                        .sentencesAlongHeight(alongHeight).build();
        DataSet ds = dsi.next();
        assertArrayEquals(expFeatures.shape(), ds.getFeatures().shape());
        assertEquals(expFeatures, ds.getFeatures());
        assertEquals(expLabels, ds.getLabels());
        assertEquals(expFeatureMask, ds.getFeaturesMaskArray());
        assertNull(ds.getLabelsMaskArray());
        // loadSingleSentence must reproduce the corresponding minibatch slice (unpadded for sentence 2)
        INDArray singleF1 = dsi.loadSingleSentence(sentences.get(0));
        INDArray singleF2 = dsi.loadSingleSentence(sentences.get(1));
        INDArray expSingle1 = ds.getFeatures().get(NDArrayIndex.interval(0, 0, true), NDArrayIndex.all(),
                        NDArrayIndex.all(), NDArrayIndex.all());
        INDArray expSingle2 = alongHeight
                        ? ds.getFeatures().get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(),
                                        NDArrayIndex.interval(0, 3), NDArrayIndex.all())
                        : ds.getFeatures().get(NDArrayIndex.interval(1, 1, true), NDArrayIndex.all(),
                                        NDArrayIndex.all(), NDArrayIndex.interval(0, 3));
        assertArrayEquals(expSingle1.shape(), singleF1.shape());
        assertArrayEquals(expSingle2.shape(), singleF2.shape());
        assertEquals(expSingle1, singleF1);
        assertEquals(expSingle2, singleF2);
    }
}
Example usage of org.nd4j.linalg.dataset.api.DataSet in the deeplearning4j project, taken from the class ParallelWrapper, method fit.
/**
 * This method takes a DataSetIterator, and starts training over it by scheduling DataSets
 * to the different worker executors. Workers are created lazily on the first call; every
 * {@code averagingFrequency} synchronization points the worker models are averaged back
 * into the wrapped model.
 *
 * @param source iterator supplying the DataSets to train on; it is reset before training starts
 * @throws ND4JIllegalStateException if the iterator yields a null DataSet
 * @throws IllegalStateException if shutdown() was invoked while fit() is still dispatching
 */
public synchronized void fit(@NonNull DataSetIterator source) {
    stopFit.set(false);
    // Lazily spin up the worker threads on the first fit() call
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread());
            // if we're using MQ here - we'd like each worker pinned to a device round-robin
            if (isMQ)
                Nd4j.getAffinityManager().attachThreadToDevice(zoo[cnt],
                                cnt % Nd4j.getAffinityManager().getNumberOfDevices());
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    }
    source.reset();
    // Optionally wrap the source in an async prefetching iterator (MagicQueue-backed when isMQ)
    DataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        if (isMQ) {
            if (workers % Nd4j.getAffinityManager().getNumberOfDevices() != 0)
                log.warn("Number of workers [{}] isn't optimal for available devices [{}]", workers,
                                Nd4j.getAffinityManager().getNumberOfDevices());
            MagicQueue queue = new MagicQueue.Builder().setCapacityPerFlow(8)
                            .setMode(MagicQueue.Mode.SEQUENTIAL)
                            .setNumberOfBuckets(Nd4j.getAffinityManager().getNumberOfDevices()).build();
            iterator = new AsyncDataSetIterator(source, prefetchSize, queue);
        } else
            iterator = new AsyncDataSetIterator(source, prefetchSize);
    } else
        iterator = source;
    // Counts how many workers have been fed since the last synchronization point
    AtomicInteger locker = new AtomicInteger(0);
    while (iterator.hasNext() && !stopFit.get()) {
        DataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as DataSet");
        /*
         now dataSet should be dispatched to next free workers, until all workers are busy. And then we should block till all finished.
        */
        int pos = locker.getAndIncrement();
        if (zoo == null)
            throw new IllegalStateException("ParallelWrapper.shutdown() has been called too early and will fail from this point forward.");
        zoo[pos].feedDataSet(dataSet);
        /*
         if all workers are dispatched now, join till all are finished
        */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();
            /*
             average model, and propagate it to whole
            */
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof MultiLayerNetwork) {
                    if (averageUpdaters) {
                        Updater updater = ((MultiLayerNetwork) model).getUpdater();
                        if (updater != null && updater.getStateViewArray() != null) {
                            if (!legacyAveraging || Nd4j.getAffinityManager().getNumberOfDevices() == 1) {
                                // Preferred path: average all worker updater state views in one call
                                List<INDArray> updaters = new ArrayList<>();
                                for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    updaters.add(workerModel.getUpdater().getStateViewArray());
                                }
                                Nd4j.averageAndPropagate(updater.getStateViewArray(), updaters);
                            } else {
                                // Legacy path: sum duplicated worker states, divide by count, push back
                                INDArray state = Nd4j.zeros(updater.getStateViewArray().shape());
                                int cnt = 0;
                                for (; cnt < workers && cnt < locker.get(); cnt++) {
                                    MultiLayerNetwork workerModel = (MultiLayerNetwork) zoo[cnt].getModel();
                                    state.addi(workerModel.getUpdater().getStateViewArray().dup());
                                }
                                state.divi(cnt);
                                updater.setStateViewArray((MultiLayerNetwork) model, state, false);
                            }
                        }
                    }
                    ((MultiLayerNetwork) model).setScore(score);
                } else if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                }
                // With legacy averaging on multi-device setups, push the averaged model to each worker
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            // Reset dispatch counter for the next round of workers
            locker.set(0);
        }
    }
    // sanity checks, or the dataset may never average
    // NOTE(review): wasAveraged is never written in this method — presumably set inside
    // getScore()/averageUpdatersState(); confirm against the rest of the class.
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    // throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
    // iterationsCounter.set(0);
}
Example usage of org.nd4j.linalg.dataset.api.DataSet in the deeplearning4j project, taken from the class TestRecordReaders, method testClassIndexOutsideOfRangeRRMDSI_MultipleReaders.
@Test
public void testClassIndexOutsideOfRangeRRMDSI_MultipleReaders() {
    // Two feature sequences, each with two identical zero-valued time steps
    Collection<Collection<Collection<Writable>>> featureSeqs = new ArrayList<>();
    for (int s = 0; s < 2; s++) {
        Collection<Collection<Writable>> featureSeq = new ArrayList<>();
        featureSeq.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
        featureSeq.add(Arrays.<Writable>asList(new DoubleWritable(0.0)));
        featureSeqs.add(featureSeq);
    }
    // Two label sequences; the second contains class index 2, out of range for 2 classes
    Collection<Collection<Collection<Writable>>> labelSeqs = new ArrayList<>();
    Collection<Collection<Writable>> labelSeq1 = new ArrayList<>();
    labelSeq1.add(Arrays.<Writable>asList(new IntWritable(0)));
    labelSeq1.add(Arrays.<Writable>asList(new IntWritable(1)));
    labelSeqs.add(labelSeq1);
    Collection<Collection<Writable>> labelSeq2 = new ArrayList<>();
    labelSeq2.add(Arrays.<Writable>asList(new IntWritable(0)));
    labelSeq2.add(Arrays.<Writable>asList(new IntWritable(2)));
    labelSeqs.add(labelSeq2);
    CollectionSequenceRecordReader featureReader = new CollectionSequenceRecordReader(featureSeqs);
    CollectionSequenceRecordReader labelReader = new CollectionSequenceRecordReader(labelSeqs);
    DataSetIterator dsi = new SequenceRecordReaderDataSetIterator(featureReader, labelReader, 2, 2);
    // The out-of-range class index must surface as a DL4JException, not any other failure
    try {
        dsi.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        System.out.println("testClassIndexOutsideOfRangeRRMDSI_MultipleReaders(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail();
    }
}
Example usage of org.nd4j.linalg.dataset.api.DataSet in the deeplearning4j project, taken from the class TestRecordReaders, method testClassIndexOutsideOfRangeRRDSI.
@Test
public void testClassIndexOutsideOfRangeRRDSI() {
    // Two records of (feature, label); the second label is class index 2, out of range for 2 classes
    Collection<Collection<Writable>> records = new ArrayList<>();
    records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(0)));
    records.add(Arrays.<Writable>asList(new DoubleWritable(0.0), new IntWritable(2)));
    CollectionRecordReader reader = new CollectionRecordReader(records);
    RecordReaderDataSetIterator iter = new RecordReaderDataSetIterator(reader, 2, 1, 2);
    // The out-of-range class index must surface as a DL4JException, not any other failure
    try {
        iter.next();
        fail("Expected exception");
    } catch (DL4JException e) {
        System.out.println("testClassIndexOutsideOfRange(): " + e.getMessage());
    } catch (Exception e) {
        e.printStackTrace();
        fail();
    }
}
Example usage of org.nd4j.linalg.dataset.api.DataSet in the deeplearning4j project, taken from the class ROCTest, method RocEvalSanityCheck.
@Test
public void RocEvalSanityCheck() {
    // Full Iris dataset in a single batch of 150
    DataSetIterator irisIter = new IrisDataSetIterator(150, 150);
    // Small 4 -> 4 -> 3 softmax classifier
    MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
                    .weightInit(WeightInit.XAVIER).list()
                    .layer(0, new DenseLayer.Builder().nIn(4).nOut(4).activation(Activation.TANH).build())
                    .layer(1, new OutputLayer.Builder().nIn(4).nOut(3).activation(Activation.SOFTMAX)
                                    .lossFunction(LossFunctions.LossFunction.MCXENT).build())
                    .build();
    MultiLayerNetwork net = new MultiLayerNetwork(conf);
    net.init();
    // Standardize both the in-memory DataSet and the iterator (via preprocessor) identically
    NormalizerStandardize normalizer = new NormalizerStandardize();
    DataSet ds = irisIter.next();
    normalizer.fit(ds);
    normalizer.transform(ds);
    irisIter.setPreProcessor(normalizer);
    for (int epoch = 0; epoch < 30; epoch++) {
        net.fit(ds);
    }
    // ROC computed by the network's own evaluation must match a manually-driven ROCMultiClass
    ROCMultiClass roc = net.evaluateROCMultiClass(irisIter, 32);
    INDArray features = ds.getFeatures();
    INDArray labels = ds.getLabels();
    INDArray predictions = net.output(features);
    ROCMultiClass manual = new ROCMultiClass(32);
    manual.eval(labels, predictions);
    for (int c = 0; c < 3; c++) {
        assertEquals(manual.calculateAUC(c), roc.calculateAUC(c), 1e-6);
        double[][] rocCurve = roc.getResultsAsArray(c);
        double[][] rocManual = manual.getResultsAsArray(c);
        assertArrayEquals(rocCurve[0], rocManual[0], 1e-6);
        assertArrayEquals(rocCurve[1], rocManual[1], 1e-6);
    }
}
Aggregations