use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
the class TestDataSetIterator method next.
/**
 * Returns the next DataSet from the wrapped iterator, counting it and applying the
 * pre-processor (when one is set) before handing it back.
 *
 * @return the next DataSet produced by the wrapped iterator
 */
@Override
public DataSet next() {
    // The counter is bumped before delegating, so it also counts a call whose delegate throws.
    numDataSets++;
    final DataSet batch = wrapped.next();
    if (preProcessor == null) {
        return batch;
    }
    preProcessor.preProcess(batch);
    return batch;
}
use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
the class MagicQueue method poll.
/**
 * This method is supposed to be called from managed thread, attached to specific device.
 * It returns 1 DataSet element from head of a backing queue, and deletes that element from that queue.
 *
 * Queue selection:
 * - THREADED mode with more than one bucket: the queue at the index reported by the affinity
 *   manager for the current thread.
 * - THREADED mode with a single bucket: queue 0.
 * - any other mode: queues are visited in round-robin order, driven by interleavedCounter.
 *
 * Please note: if there's nothing available in the selected queue - NULL will be returned
 *
 * @param time time to wait for something appear in queue
 * @param timeUnit TimeUnit for time param
 * @return the polled DataSet, or null if the selected queue produced nothing
 * @throws InterruptedException propagated from the selected queue's poll call
 */
public DataSet poll(long time, TimeUnit timeUnit) throws InterruptedException {
    final int queueIdx;
    if (mode == Mode.THREADED) {
        queueIdx = numberOfBuckets > 1 ? Nd4j.getAffinityManager().getDeviceForCurrentThread() : 0;
    } else {
        // Claim the current slot and advance (with wrap-around) in one atomic step. The previous
        // increment-then-reset sequence was not atomic and only reset the counter after the
        // poll returned, so a concurrent caller or an interrupted poll could leave the counter
        // equal to backingQueues.size() and make the next lookup go out of bounds.
        final int queueCount = backingQueues.size();
        int current;
        int following;
        do {
            current = interleavedCounter.get();
            following = (current + 1 >= queueCount) ? 0 : current + 1;
        } while (!interleavedCounter.compareAndSet(current, following));
        queueIdx = current;
    }

    DataSet ds = backingQueues.get(queueIdx).poll(time, timeUnit);
    // Only successful retrievals are counted.
    if (ds != null) {
        cntGet.incrementAndGet();
    }
    return ds;
}
use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
the class RecordReaderDataSetIterator method getDataSet.
/**
 * Converts a single record (a list of Writable values) into a DataSet of features and labels.
 *
 * Cases, checked in this order:
 * 1. The record has exactly two entries and both are the very same NDArrayWritable instance:
 *    its array is used as both features and labels.
 * 2. The record has exactly two entries and the first is an NDArrayWritable: the first entry is
 *    the feature array and the second entry supplies the label (built with
 *    FeatureUtil.toOutcomeVector when not in regression mode, otherwise taken as an array or scalar).
 * 3. Anything else: the entries are walked by column index; columns inside the label range build
 *    the label and all remaining columns fill the feature vector.
 *
 * Side effect: when numPossibleLabels >= 1 and labelIndex is negative, labelIndex is overwritten
 * with the index of the last column of this record.
 *
 * @param record the values of one example
 * @return the resulting DataSet; if labelIndex is still negative at the end, the feature vector
 *         is also used as the labels
 * @throws IllegalStateException if a label column is reached while numPossibleLabels is below 1
 * @throws DL4JInvalidInputException if a classification label is outside 0..numPossibleLabels-1
 */
private DataSet getDataSet(List<Writable> record) {
List<Writable> currList;
// NOTE(review): record is declared as a List, so this instanceof is true for every non-null
// argument; the else branch looks reachable only for null, where the copy would presumably fail.
if (record instanceof List)
currList = record;
else
currList = new ArrayList<>(record);
//allow people to specify label index as -1 and infer the last possible label
if (numPossibleLabels >= 1 && labelIndex < 0) {
labelIndex = record.size() - 1;
}
INDArray label = null;
INDArray featureVector = null;
// Next write positions inside featureVector and label respectively.
int featureCount = 0;
int labelCount = 0;
//no labels: both entries are the same NDArrayWritable instance, so its array serves as features and labels
if (currList.size() == 2 && currList.get(1) instanceof NDArrayWritable && currList.get(0) instanceof NDArrayWritable && currList.get(0) == currList.get(1)) {
NDArrayWritable writable = (NDArrayWritable) currList.get(0);
return new DataSet(writable.get(), writable.get());
}
// Two entries with an array first: entry 0 is the feature array, entry 1 supplies the label.
if (currList.size() == 2 && currList.get(0) instanceof NDArrayWritable) {
if (!regression) {
// The label is parsed from the entry's string form as a double and truncated to an int class index.
label = FeatureUtil.toOutcomeVector((int) Double.parseDouble(currList.get(1).toString()), numPossibleLabels);
} else {
if (currList.get(1) instanceof NDArrayWritable) {
label = ((NDArrayWritable) currList.get(1)).get();
} else {
label = Nd4j.scalar(currList.get(1).toDouble());
}
}
NDArrayWritable ndArrayWritable = (NDArrayWritable) currList.get(0);
featureVector = ndArrayWritable.get();
return new DataSet(featureVector, label);
}
// General case: walk the columns and route each one to the label or to the feature vector.
for (int j = 0; j < currList.size(); j++) {
Writable current = currList.get(j);
//Skip empty entries; the NDArrayWritable check comes first because calling toString() on an ndarray writable is an insane slow down here
if (!(current instanceof NDArrayWritable) && current.toString().isEmpty())
continue;
if (regression && j == labelIndex && j == labelIndexTo && current instanceof NDArrayWritable) {
//Case: NDArrayWritable for the labels
label = ((NDArrayWritable) current).get();
} else if (regression && j >= labelIndex && j <= labelIndexTo) {
//This is the multi-label regression case
// The label row is allocated lazily with one slot per column in labelIndex..labelIndexTo.
if (label == null)
label = Nd4j.create(1, (labelIndexTo - labelIndex + 1));
label.putScalar(labelCount++, current.toDouble());
} else if (labelIndex >= 0 && j == labelIndex) {
//single label case (classification, etc)
// NOTE(review): a conversion failure is only printed; processing then continues with the
// unconverted value.
if (converter != null)
try {
current = converter.convert(current);
} catch (WritableConverterException e) {
e.printStackTrace();
}
if (numPossibleLabels < 1)
throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
if (regression) {
label = Nd4j.scalar(current.toDouble());
} else {
int curr = current.toInt();
if (curr < 0 || curr >= numPossibleLabels) {
throw new DL4JInvalidInputException("Invalid classification data: expect label value (at label index column = " + labelIndex + ") to be in range 0 to " + (numPossibleLabels - 1) + " inclusive (0 to numClasses-1, with numClasses=" + numPossibleLabels + "); got label value of " + current);
}
label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels);
}
} else {
// Feature column: try it as a scalar first and fall back to a whole array below.
try {
double value = current.toDouble();
// The feature vector is allocated on the first scalar feature, sized as the record
// length minus the number of label columns.
if (featureVector == null) {
if (regression && labelIndex >= 0) {
//Handle the possibly multi-label regression case here:
int nLabels = labelIndexTo - labelIndex + 1;
featureVector = Nd4j.create(1, currList.size() - nLabels);
} else {
//Classification case, and also no-labels case
// NOTE(review): this call passes a single length, while the regression branch above
// passes (1, n) - confirm both yield the same array shape.
featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size());
}
}
featureVector.putScalar(featureCount++, value);
} catch (UnsupportedOperationException e) {
// This isn't a scalar, so check if we got an array already
if (current instanceof NDArrayWritable) {
// Only checked when assertions are enabled (-ea); otherwise an existing vector is silently replaced.
assert featureVector == null;
featureVector = ((NDArrayWritable) current).get();
} else {
throw e;
}
}
}
}
// Without a label column the features double as the labels.
return new DataSet(featureVector, labelIndex >= 0 ? label : featureVector);
}
use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
the class RecordReaderDataSetIterator method totalOutcomes.
/**
 * Reports the number of outcomes of the most recently produced DataSet.
 *
 * If no DataSet has been produced yet, one batch is fetched through next(), cached in
 * {@code last}, and flagged with {@code useCurrent} so that the following call to
 * next(int) returns that cached batch.
 *
 * @return numOutcomes() of the cached or newly fetched DataSet
 */
@Override
public int totalOutcomes() {
    if (last != null) {
        return last.numOutcomes();
    }
    DataSet fetched = next();
    last = fetched;
    useCurrent = true;
    return fetched.numOutcomes();
}
use of org.nd4j.linalg.dataset.DataSet in project deeplearning4j by deeplearning4j.
the class RecordReaderDataSetIterator method next.
/**
 * Builds one batch of up to {@code num} examples from the record reader.
 *
 * If a batch was cached earlier and flagged with {@code useCurrent}, that cached batch is
 * returned instead (after running the pre-processor on it) and no records are read.
 * Otherwise records are read until {@code num} examples are collected or hasNext() reports
 * no more data, the per-example DataSets are merged, and the merged batch is remembered in
 * {@code last}.
 *
 * @param num maximum number of examples to put in the batch
 * @return the merged batch, or an empty DataSet when no example could be read
 */
@Override
public DataSet next(int num) {
    if (useCurrent) {
        useCurrent = false;
        // NOTE(review): if the cached batch came out of this method it was already pre-processed
        // below, so this may apply the pre-processor a second time - confirm.
        if (preProcessor != null) {
            preProcessor.preProcess(last);
        }
        return last;
    }

    List<DataSet> examples = new ArrayList<>();
    List<RecordMetaData> metaData = null;
    if (collectMetaData) {
        metaData = new ArrayList<RecordMetaData>();
    }

    for (int count = 0; count < num && hasNext(); count++) {
        if (recordReader instanceof SequenceRecordReader) {
            // Pull a new sequence only once the current one is exhausted; each step of the
            // sequence becomes its own example.
            if (sequenceIter == null || !sequenceIter.hasNext()) {
                List<List<Writable>> sequence = ((SequenceRecordReader) recordReader).sequenceRecord();
                sequenceIter = sequence.iterator();
            }
            List<Writable> step = sequenceIter.next();
            examples.add(getDataSet(step));
        } else if (collectMetaData) {
            Record withMeta = recordReader.nextRecord();
            examples.add(getDataSet(withMeta.getRecord()));
            metaData.add(withMeta.getMetaData());
        } else {
            List<Writable> values = recordReader.next();
            examples.add(getDataSet(values));
        }
    }

    // The batch counter advances even when nothing could be read.
    batchNum++;
    if (examples.isEmpty()) {
        return new DataSet();
    }

    DataSet merged = DataSet.merge(examples);
    if (collectMetaData) {
        merged.setExampleMetaData(metaData);
    }
    last = merged;
    if (preProcessor != null) {
        preProcessor.preProcess(merged);
    }
    //Add label name values to dataset
    if (recordReader.getLabels() != null) {
        merged.setLabelNames(recordReader.getLabels());
    }
    return merged;
}
Aggregations