Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j (by deeplearning4j).
The class IteratorMultiDataSetIterator, method next:
@Override
public MultiDataSet next(int num) {
    if (!hasNext())
        throw new NoSuchElementException();

    List<MultiDataSet> list = new ArrayList<>();
    int countSoFar = 0;
    while ((!queued.isEmpty() || iterator.hasNext()) && countSoFar < batchSize) {
        MultiDataSet next;
        if (!queued.isEmpty()) {
            next = queued.removeFirst();
        } else {
            next = iterator.next();
        }
        int nExamples = next.getFeatures(0).size(0);

        if (countSoFar + nExamples <= batchSize) {
            //Add the entire MultiDataSet as-is
            list.add(next);
        } else {
            //Split the MultiDataSet
            int nFeatures = next.numFeatureArrays();
            int nLabels = next.numLabelsArrays();
            INDArray[] fToKeep = new INDArray[nFeatures];
            INDArray[] lToKeep = new INDArray[nLabels];
            INDArray[] fToCache = new INDArray[nFeatures];
            INDArray[] lToCache = new INDArray[nLabels];
            INDArray[] fMaskToKeep = (next.getFeaturesMaskArrays() != null ? new INDArray[nFeatures] : null);
            INDArray[] lMaskToKeep = (next.getLabelsMaskArrays() != null ? new INDArray[nLabels] : null);
            INDArray[] fMaskToCache = (next.getFeaturesMaskArrays() != null ? new INDArray[nFeatures] : null);
            INDArray[] lMaskToCache = (next.getLabelsMaskArrays() != null ? new INDArray[nLabels] : null);

            for (int i = 0; i < nFeatures; i++) {
                INDArray fi = next.getFeatures(i);
                fToKeep[i] = getRange(fi, 0, batchSize - countSoFar);
                fToCache[i] = getRange(fi, batchSize - countSoFar, nExamples);
                if (fMaskToKeep != null) {
                    INDArray fmi = next.getFeaturesMaskArray(i);
                    fMaskToKeep[i] = getRange(fmi, 0, batchSize - countSoFar);
                    fMaskToCache[i] = getRange(fmi, batchSize - countSoFar, nExamples);
                }
            }
            for (int i = 0; i < nLabels; i++) {
                INDArray li = next.getLabels(i);
                lToKeep[i] = getRange(li, 0, batchSize - countSoFar);
                lToCache[i] = getRange(li, batchSize - countSoFar, nExamples);
                if (lMaskToKeep != null) {
                    INDArray lmi = next.getLabelsMaskArray(i);
                    lMaskToKeep[i] = getRange(lmi, 0, batchSize - countSoFar);
                    lMaskToCache[i] = getRange(lmi, batchSize - countSoFar, nExamples);
                }
            }

            MultiDataSet toKeep = new org.nd4j.linalg.dataset.MultiDataSet(fToKeep, lToKeep, fMaskToKeep, lMaskToKeep);
            MultiDataSet toCache = new org.nd4j.linalg.dataset.MultiDataSet(fToCache, lToCache, fMaskToCache, lMaskToCache);
            list.add(toKeep);
            queued.add(toCache);
        }
        countSoFar += nExamples;
    }

    MultiDataSet out;
    if (list.size() == 1) {
        out = list.get(0);
    } else {
        out = org.nd4j.linalg.dataset.MultiDataSet.merge(list);
    }

    if (preProcessor != null)
        preProcessor.preProcess(out);
    return out;
}
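For orientation, here is a minimal usage sketch of the wrapper above. The constructor taking an Iterator<MultiDataSet> plus a batch size matches this class's public API; the two small random MultiDataSets are assumed test data, not code from the repository.

// A minimal sketch: wrap a plain Iterator<MultiDataSet> to get fixed-size minibatches.
// The 3- and 5-example source sets are illustrative assumptions.
List<MultiDataSet> source = Arrays.asList(
        new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.rand(3, 10), Nd4j.rand(3, 2)),
        new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.rand(5, 10), Nd4j.rand(5, 2)));
MultiDataSetIterator iter = new IteratorMultiDataSetIterator(source.iterator(), 4);
while (iter.hasNext()) {
    MultiDataSet batch = iter.next();
    // First batch: 3-example set merged with 1 split-off example; the remaining
    // 4 examples are queued and returned as the second batch.
    System.out.println(batch.getFeatures(0).size(0)); // prints 4, then 4
}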
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j (by deeplearning4j).
The class DataSetLossCalculatorCG, method calculateScore:
@Override
public double calculateScore(ComputationGraph network) {
    double lossSum = 0.0;
    int exCount = 0;
    if (dataSetIterator != null) {
        dataSetIterator.reset();
        while (dataSetIterator.hasNext()) {
            DataSet dataSet = dataSetIterator.next();
            int nEx = dataSet.getFeatureMatrix().size(0);
            lossSum += network.score(dataSet) * nEx;
            exCount += nEx;
        }
    } else {
        multiDataSetIterator.reset();
        while (multiDataSetIterator.hasNext()) {
            MultiDataSet dataSet = multiDataSetIterator.next();
            int nEx = dataSet.getFeatures(0).size(0);
            lossSum += network.score(dataSet) * nEx;
            exCount += nEx;
        }
    }
    if (average)
        return lossSum / exCount;
    else
        return lossSum;
}
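The score is a per-example weighted average: each batch's score is scaled by its example count before summing, so a smaller final batch does not skew the result. A standalone sketch of that arithmetic, using made-up batch scores and sizes:

// Weighted-average sketch with assumed values: two batches of 32 and 8 examples
double[] batchScores = {0.50, 0.40};
int[] batchSizes = {32, 8};
double lossSum = 0.0;
int exCount = 0;
for (int i = 0; i < batchScores.length; i++) {
    lossSum += batchScores[i] * batchSizes[i]; // weight each score by its batch size
    exCount += batchSizes[i];
}
double avg = lossSum / exCount; // (0.50*32 + 0.40*8) / 40 = 0.48, not the naive mean 0.45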
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j (by deeplearning4j).
The class ComputationGraph, method pretrainLayer:
/**
 * Pretrain a specified layer with the given MultiDataSetIterator
 *
 * @param layerName Layer name
 * @param iter      Training data
 */
public void pretrainLayer(String layerName, MultiDataSetIterator iter) {
    if (!configuration.isPretrain())
        return;
    if (flattenedGradients == null)
        initGradientsView();
    if (!verticesMap.containsKey(layerName)) {
        throw new IllegalStateException("Invalid vertex name: " + layerName);
    }
    if (!verticesMap.get(layerName).hasLayer()) {
        //No op
        return;
    }

    int layerIndex = verticesMap.get(layerName).getVertexIndex();

    //Need to do a partial forward pass. Simply following the topological ordering won't be efficient, as we
    // might end up doing a forward pass on layers we don't need.
    //However, we can start with the topological order and prune out any layers we don't need
    LinkedList<Integer> partialTopoSort = new LinkedList<>();
    Set<Integer> seenSoFar = new HashSet<>();
    partialTopoSort.add(topologicalOrder[layerIndex]);
    seenSoFar.add(topologicalOrder[layerIndex]);
    for (int j = layerIndex - 1; j >= 0; j--) {
        //Do we need to do a forward pass on this GraphVertex?
        //If it is input to any other layer we need, then yes. Otherwise: no
        VertexIndices[] outputsTo = vertices[topologicalOrder[j]].getOutputVertices();
        boolean needed = false;
        for (VertexIndices vi : outputsTo) {
            if (seenSoFar.contains(vi.getVertexIndex())) {
                needed = true;
                break;
            }
        }
        if (needed) {
            partialTopoSort.addFirst(topologicalOrder[j]);
            seenSoFar.add(topologicalOrder[j]);
        }
    }

    int[] fwdPassOrder = new int[partialTopoSort.size()];
    int k = 0;
    for (Integer g : partialTopoSort)
        fwdPassOrder[k++] = g;

    GraphVertex gv = vertices[fwdPassOrder[fwdPassOrder.length - 1]];
    Layer layer = gv.getLayer();

    if (!iter.hasNext() && iter.resetSupported()) {
        iter.reset();
    }
    while (iter.hasNext()) {
        MultiDataSet multiDataSet = iter.next();
        setInputs(multiDataSet.getFeatures());

        for (int j = 0; j < fwdPassOrder.length - 1; j++) {
            GraphVertex current = vertices[fwdPassOrder[j]];
            if (current.isInputVertex()) {
                VertexIndices[] inputsTo = current.getOutputVertices();
                INDArray input = inputs[current.getVertexIndex()];
                for (VertexIndices v : inputsTo) {
                    int vIdx = v.getVertexIndex();
                    int vIdxInputNum = v.getVertexEdgeNumber();
                    //This input: the 'vIdxInputNum'th input to vertex 'vIdx'
                    //TODO When to dup?
                    vertices[vIdx].setInput(vIdxInputNum, input.dup());
                }
            } else {
                //Do forward pass:
                INDArray out = current.doForward(true);
                //Now, set the inputs for the next vertices:
                VertexIndices[] outputsTo = current.getOutputVertices();
                if (outputsTo != null) {
                    for (VertexIndices v : outputsTo) {
                        int vIdx = v.getVertexIndex();
                        int inputNum = v.getVertexEdgeNumber();
                        //This (jth) connection from the output: is the 'inputNum'th input to vertex 'vIdx'
                        vertices[vIdx].setInput(inputNum, out);
                    }
                }
            }
        }
        //At this point: all required forward-pass work is done. Can now pretrain the layer on the current input
        layer.fit(gv.getInputs()[0]);
        layer.conf().setPretrain(false);
    }
}
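A minimal call-site sketch, assuming DL4J 0.x-era configuration APIs (pretrain(true)/backprop(false) on the graph builder) and an unsupervised layer named "vae"; the layer sizes and the trainIter iterator are illustrative, not from the repository:

// Sketch: a graph with one pretrainable (variational autoencoder) layer
ComputationGraphConfiguration conf = new NeuralNetConfiguration.Builder()
        .graphBuilder()
        .addInputs("in")
        .addLayer("vae", new VariationalAutoencoder.Builder()
                .nIn(784).nOut(32)
                .encoderLayerSizes(256).decoderLayerSizes(256)
                .build(), "in")
        .setOutputs("vae")
        .pretrain(true).backprop(false)
        .build();
ComputationGraph net = new ComputationGraph(conf);
net.init();
// trainIter: any MultiDataSetIterator with 784-feature inputs (assumed to exist)
net.pretrainLayer("vae", trainIter); // forward-passes only the vertices feeding "vae", then fits it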
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j (by deeplearning4j).
The class ParallelWrapper, method fit:
/**
 * Fit the model in parallel using the given MultiDataSet iterator.
 *
 * @param source MultiDataSetIterator providing the training data
 */
public synchronized void fit(@NonNull MultiDataSetIterator source) {
    stopFit.set(false);
    if (zoo == null) {
        zoo = new Trainer[workers];
        for (int cnt = 0; cnt < workers; cnt++) {
            // we pass true here, to tell the Trainer to use the MultiDataSet queue for training
            zoo[cnt] = new Trainer(cnt, model, Nd4j.getAffinityManager().getDeviceForCurrentThread(), true);
            zoo[cnt].setUncaughtExceptionHandler(handler);
            zoo[cnt].start();
        }
    } else {
        for (int cnt = 0; cnt < workers; cnt++) {
            zoo[cnt].useMDS = true;
        }
    }
    source.reset();

    MultiDataSetIterator iterator;
    if (prefetchSize > 0 && source.asyncSupported()) {
        iterator = new AsyncMultiDataSetIterator(source, prefetchSize);
    } else
        iterator = source;

    AtomicInteger locker = new AtomicInteger(0);
    while (iterator.hasNext() && !stopFit.get()) {
        MultiDataSet dataSet = iterator.next();
        if (dataSet == null)
            throw new ND4JIllegalStateException("You can't have NULL as MultiDataSet");

        /*
           now dataSet should be dispatched to the next free worker, until all workers are busy. And then we should block till all are finished.
        */
        int pos = locker.getAndIncrement();
        zoo[pos].feedMultiDataSet(dataSet);

        /*
           if all workers are dispatched now, join till all are finished
        */
        if (pos + 1 == workers || !iterator.hasNext()) {
            iterationsCounter.incrementAndGet();
            for (int cnt = 0; cnt < workers && cnt < locker.get(); cnt++) {
                try {
                    zoo[cnt].waitTillRunning();
                } catch (Exception e) {
                    throw new RuntimeException(e);
                }
            }
            Nd4j.getMemoryManager().invokeGcOccasionally();

            /*
               average the model, and propagate it to all workers
            */
            if (iterationsCounter.get() % averagingFrequency == 0 && pos + 1 == workers) {
                double score = getScore(locker);
                // averaging updaters state
                if (model instanceof ComputationGraph) {
                    averageUpdatersState(locker, score);
                } else
                    throw new RuntimeException("MultiDataSet must only be used with a ComputationGraph model");
                if (legacyAveraging && Nd4j.getAffinityManager().getNumberOfDevices() > 1) {
                    for (int cnt = 0; cnt < workers; cnt++) {
                        zoo[cnt].updateModel(model);
                    }
                }
            }
            locker.set(0);
        }
    }

    // sanity check: warn if parameters were never averaged during this fit() call
    if (!wasAveraged)
        log.warn("Parameters were never averaged on current fit(). Ratios of batch size, num workers, and averaging frequency may be responsible.");
    // throw new IllegalStateException("Parameters were never averaged. Please check batch size ratios, number of workers, and your averaging frequency.");
    log.debug("Iterations passed: {}", iterationsCounter.get());
    // iterationsCounter.set(0);
}
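A call-site sketch, assuming the builder API of this ParallelWrapper version; net is taken to be an initialized ComputationGraph and mdsIter an existing MultiDataSetIterator:

// Sketch: 4 workers, async prefetch of 8 batches, parameter averaging every 3 iterations
ParallelWrapper wrapper = new ParallelWrapper.Builder<>(net)
        .workers(4)
        .prefetchBuffer(8)
        .averagingFrequency(3)
        .build();
wrapper.fit(mdsIter); // batches are dispatched round-robin to the workers as shown above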
Use of org.nd4j.linalg.dataset.api.MultiDataSet in project deeplearning4j (by deeplearning4j).
The class MultiDataSetExportFunction, method call:
@Override
public void call(Iterator<MultiDataSet> iter) throws Exception {
    String jvmuid = UIDProvider.getJVMUID();
    uid = Thread.currentThread().getId() + jvmuid.substring(0, Math.min(8, jvmuid.length()));
    while (iter.hasNext()) {
        MultiDataSet next = iter.next();
        String filename = "mds_" + uid + "_" + (outputCount++) + ".bin";
        String path = outputDir.getPath();
        URI uri = new URI(path + (path.endsWith("/") || path.endsWith("\\") ? "" : "/") + filename);
        FileSystem file = FileSystem.get(uri, conf);
        try (FSDataOutputStream out = file.create(new Path(uri))) {
            next.save(out);
        }
    }
}
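A Spark call-site sketch; the local context, the single random MultiDataSet, and the output path are assumptions for illustration, while the constructor taking a destination URI matches this class's signature. Each partition then writes its MultiDataSets as mds_<uid>_<n>.bin files under the output directory:

// Sketch: export an RDD of MultiDataSets to binary files via foreachPartition
JavaSparkContext sc = new JavaSparkContext(new SparkConf().setMaster("local[*]").setAppName("mds-export"));
List<MultiDataSet> examples = Collections.singletonList(
        new org.nd4j.linalg.dataset.MultiDataSet(Nd4j.rand(4, 10), Nd4j.rand(4, 2)));
JavaRDD<MultiDataSet> rdd = sc.parallelize(examples);
rdd.foreachPartition(new MultiDataSetExportFunction(new URI("file:///tmp/mds_export/")));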