Use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j:
the convert method of the class NDArrayRecordToNDArray.
@Override
public INDArray convert(Collection<Collection<Writable>> records) {
    INDArray[] concat = new INDArray[records.size()];
    int count = 0;
    for (Collection<Writable> record : records) {
        //each record is expected to hold a single NDArrayWritable as its first element
        NDArrayWritable writable = (NDArrayWritable) record.iterator().next();
        concat[count++] = writable.get();
    }
    //stack the per-record arrays along dimension 0 to form one batch array
    return Nd4j.concat(0, concat);
}
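A minimal usage sketch (the no-arg constructor and the record contents are illustrative assumptions, not from the source): two single-writable records, each wrapping a 1x3 row vector, are stacked into a 2x3 batch.

//Hypothetical usage: stack two NDArrayWritable records into one batch array
NDArrayRecordToNDArray converter = new NDArrayRecordToNDArray();
Collection<Collection<Writable>> records = new ArrayList<>();
records.add(Collections.<Writable>singletonList(new NDArrayWritable(Nd4j.create(new double[] { 1, 2, 3 }))));
records.add(Collections.<Writable>singletonList(new NDArrayWritable(Nd4j.create(new double[] { 4, 5, 6 }))));
INDArray batch = converter.convert(records);
//batch has shape [2, 3]: one row per record, concatenated along dimension 0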
Use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j:
the call method of the class DataVecDataSetFunction.
@Override
public DataSet call(List<Writable> currList) throws Exception {
    //allow people to specify the label index as -1 and infer the last column as the label
    int labelIndex = this.labelIndex;
    if (numPossibleLabels >= 1 && labelIndex < 0) {
        labelIndex = currList.size() - 1;
    }
    INDArray label = null;
    INDArray featureVector = null;
    int featureCount = 0;
    int labelCount = 0;
    //no-labels case: a single NDArrayWritable serves as both features and labels
    if (currList.size() == 2 && currList.get(1) instanceof NDArrayWritable
                    && currList.get(0) instanceof NDArrayWritable && currList.get(0) == currList.get(1)) {
        NDArrayWritable writable = (NDArrayWritable) currList.get(0);
        DataSet ds = new DataSet(writable.get(), writable.get());
        if (preProcessor != null)
            preProcessor.preProcess(ds);
        return ds;
    }
    //shortcut: NDArrayWritable features plus a single label column
    if (currList.size() == 2 && currList.get(0) instanceof NDArrayWritable) {
        if (!regression)
            label = FeatureUtil.toOutcomeVector((int) Double.parseDouble(currList.get(1).toString()), numPossibleLabels);
        else
            label = Nd4j.scalar(Double.parseDouble(currList.get(1).toString()));
        NDArrayWritable ndArrayWritable = (NDArrayWritable) currList.get(0);
        featureVector = ndArrayWritable.get();
        DataSet ds = new DataSet(featureVector, label);
        if (preProcessor != null)
            preProcessor.preProcess(ds);
        return ds;
    }
    for (int j = 0; j < currList.size(); j++) {
        Writable current = currList.get(j);
        //toString() on an NDArrayWritable is a severe slowdown, so only do the empty check for other writables
        if (!(current instanceof NDArrayWritable) && current.toString().isEmpty())
            continue;
        if (labelIndex >= 0 && j >= labelIndex && j <= labelIndexTo) {
            //label column(s): classification, single-label regression, multi-label regression
            if (converter != null) {
                try {
                    current = converter.convert(current);
                } catch (WritableConverterException e) {
                    e.printStackTrace();
                }
            }
            if (regression) {
                //single- and multi-label regression
                if (label == null) {
                    label = Nd4j.zeros(labelIndexTo - labelIndex + 1);
                }
                label.putScalar(0, labelCount++, current.toDouble());
            } else {
                if (numPossibleLabels < 1)
                    throw new IllegalStateException("Number of possible labels invalid, must be >= 1 for classification");
                int curr = current.toInt();
                if (curr >= numPossibleLabels)
                    throw new IllegalStateException("Invalid index: got index " + curr + " but numPossibleLabels is "
                                    + numPossibleLabels + " (must be 0 <= idx < numPossibleLabels)");
                label = FeatureUtil.toOutcomeVector(curr, numPossibleLabels);
            }
        } else {
            try {
                double value = current.toDouble();
                if (featureVector == null) {
                    if (regression && labelIndex >= 0) {
                        //handle the possibly multi-label regression case: subtract all label columns
                        int nLabels = labelIndexTo - labelIndex + 1;
                        featureVector = Nd4j.create(1, currList.size() - nLabels);
                    } else {
                        //classification case, and also the no-labels case
                        featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size());
                    }
                }
                featureVector.putScalar(featureCount++, value);
            } catch (UnsupportedOperationException e) {
                //this isn't a scalar, so check if we got an array already
                if (current instanceof NDArrayWritable) {
                    assert featureVector == null;
                    featureVector = ((NDArrayWritable) current).get();
                } else {
                    throw e;
                }
            }
        }
    }
    DataSet ds = new DataSet(featureVector, (labelIndex >= 0 ? label : featureVector));
    if (preProcessor != null)
        preProcessor.preProcess(ds);
    return ds;
}
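For context, a sketch of how this function is typically applied to an RDD of DataVec rows (the row values, the three-argument constructor, and the pre-existing JavaSparkContext sc are assumptions for illustration):

//Hypothetical row: two numeric features followed by a class index in the last column
List<Writable> row = Arrays.<Writable>asList(new DoubleWritable(0.5), new DoubleWritable(1.5), new IntWritable(2));
JavaRDD<List<Writable>> rows = sc.parallelize(Collections.singletonList(row));
//labelIndex = 2 (last column), numPossibleLabels = 3, regression = false
JavaRDD<DataSet> dataSets = rows.map(new DataVecDataSetFunction(2, 3, false));
//each resulting DataSet holds a 1x2 feature vector and a one-hot 1x3 label vector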
Use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j:
the call method of the class DataVecSequencePairDataSetFunction.
@Override
public DataSet call(Tuple2<List<List<Writable>>, List<List<Writable>>> input) throws Exception {
    List<List<Writable>> featuresSeq = input._1();
    List<List<Writable>> labelsSeq = input._2();
    int featuresLength = featuresSeq.size();
    int labelsLength = labelsSeq.size();
    Iterator<List<Writable>> fIter = featuresSeq.iterator();
    Iterator<List<Writable>> lIter = labelsSeq.iterator();
    INDArray inputArr = null;
    INDArray outputArr = null;
    //idx = [miniBatch, featureIdx, timeStep]; the mini-batch index stays 0 for a single example
    int[] idx = new int[3];
    int i = 0;
    while (fIter.hasNext()) {
        List<Writable> step = fIter.next();
        if (i == 0) {
            int[] inShape = new int[] { 1, step.size(), featuresLength };
            inputArr = Nd4j.create(inShape);
        }
        Iterator<Writable> timeStepIter = step.iterator();
        int f = 0;
        idx[1] = 0;
        while (timeStepIter.hasNext()) {
            Writable current = timeStepIter.next();
            if (converter != null)
                current = converter.convert(current);
            try {
                inputArr.putScalar(idx, current.toDouble());
            } catch (UnsupportedOperationException e) {
                //this isn't a scalar, so check if we got an array already
                if (current instanceof NDArrayWritable) {
                    inputArr.get(NDArrayIndex.point(idx[0]), NDArrayIndex.all(), NDArrayIndex.point(idx[2]))
                                    .putRow(0, ((NDArrayWritable) current).get());
                } else {
                    throw e;
                }
            }
            idx[1] = ++f;
        }
        idx[2] = ++i;
    }
    idx = new int[3];
    i = 0;
    while (lIter.hasNext()) {
        List<Writable> step = lIter.next();
        if (i == 0) {
            int[] outShape = new int[] { 1, (regression ? step.size() : numPossibleLabels), labelsLength };
            outputArr = Nd4j.create(outShape);
        }
        Iterator<Writable> timeStepIter = step.iterator();
        int f = 0;
        idx[1] = 0;
        if (regression) {
            //load all values without modification
            while (timeStepIter.hasNext()) {
                Writable current = timeStepIter.next();
                if (converter != null)
                    current = converter.convert(current);
                outputArr.putScalar(idx, current.toDouble());
                idx[1] = ++f;
            }
        } else {
            //expect a single value (class index) -> convert to a one-hot vector
            Writable value = timeStepIter.next();
            int labelClassIdx = value.toInt();
            INDArray line = FeatureUtil.toOutcomeVector(labelClassIdx, numPossibleLabels);
            //for shape [1, nOut, timeSeriesLength], tensor i along dimension 1 is the output at time step i
            outputArr.tensorAlongDimension(i, 1).assign(line);
        }
        idx[2] = ++i;
    }
    DataSet ds;
    if (alignmentMode == AlignmentMode.EQUAL_LENGTH || featuresLength == labelsLength) {
        ds = new DataSet(inputArr, outputArr);
    } else if (alignmentMode == AlignmentMode.ALIGN_END) {
        if (featuresLength > labelsLength) {
            //input is longer: pad the output at the start
            INDArray newOutput = Nd4j.create(1, outputArr.size(1), featuresLength);
            newOutput.get(NDArrayIndex.point(0), NDArrayIndex.all(),
                            NDArrayIndex.interval(featuresLength - labelsLength, featuresLength)).assign(outputArr);
            //need an output mask array, but not an input mask array
            INDArray outputMask = Nd4j.create(1, featuresLength);
            for (int j = featuresLength - labelsLength; j < featuresLength; j++)
                outputMask.putScalar(j, 1.0);
            ds = new DataSet(inputArr, newOutput, Nd4j.ones(outputMask.shape()), outputMask);
        } else {
            //output is longer: pad the input at the start
            INDArray newInput = Nd4j.create(1, inputArr.size(1), labelsLength);
            newInput.get(NDArrayIndex.point(0), NDArrayIndex.all(),
                            NDArrayIndex.interval(labelsLength - featuresLength, labelsLength)).assign(inputArr);
            //need an input mask array, but not an output mask array
            INDArray inputMask = Nd4j.create(1, labelsLength);
            for (int j = labelsLength - featuresLength; j < labelsLength; j++)
                inputMask.putScalar(j, 1.0);
            ds = new DataSet(newInput, outputArr, inputMask, Nd4j.ones(inputMask.shape()));
        }
    } else if (alignmentMode == AlignmentMode.ALIGN_START) {
        if (featuresLength > labelsLength) {
            //input is longer: pad the output at the end
            INDArray newOutput = Nd4j.create(1, outputArr.size(1), featuresLength);
            newOutput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, labelsLength))
                            .assign(outputArr);
            //need an output mask array, but not an input mask array
            INDArray outputMask = Nd4j.create(1, featuresLength);
            for (int j = 0; j < labelsLength; j++)
                outputMask.putScalar(j, 1.0);
            ds = new DataSet(inputArr, newOutput, Nd4j.ones(outputMask.shape()), outputMask);
        } else {
            //output is longer: pad the input at the end
            INDArray newInput = Nd4j.create(1, inputArr.size(1), labelsLength);
            newInput.get(NDArrayIndex.point(0), NDArrayIndex.all(), NDArrayIndex.interval(0, featuresLength))
                            .assign(inputArr);
            //need an input mask array, but not an output mask array
            INDArray inputMask = Nd4j.create(1, labelsLength);
            for (int j = 0; j < featuresLength; j++)
                inputMask.putScalar(j, 1.0);
            ds = new DataSet(newInput, outputArr, inputMask, Nd4j.ones(inputMask.shape()));
        }
    } else {
        throw new UnsupportedOperationException("Invalid alignment mode: " + alignmentMode);
    }
    if (preProcessor != null)
        preProcessor.preProcess(ds);
    return ds;
}
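A usage sketch for this pair function (the RDD names, the constructor arguments, and the nested AlignmentMode enum location are assumptions; note that zip requires both RDDs to have the same partitioning and element count):

//Hypothetical usage: zip a feature-sequence RDD with a label-sequence RDD, then map each pair to a DataSet
JavaPairRDD<List<List<Writable>>, List<List<Writable>>> pairs = featureSequences.zip(labelSequences);
JavaRDD<DataSet> data = pairs.map(new DataVecSequencePairDataSetFunction(4, false,
                DataVecSequencePairDataSetFunction.AlignmentMode.ALIGN_END));
//4 possible classes, classification (regression = false), labels aligned to the end of the input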
Use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j:
the call method of the class RecordReaderFunction.
@Override
public DataSet call(String v1) throws Exception {
    recordReader.initialize(new StringSplit(v1));
    List<DataSet> dataSets = new ArrayList<>();
    //recordReader.next() already returns a List<Writable>, so no conversion is needed
    List<Writable> currList = recordReader.next();
    INDArray label = null;
    INDArray featureVector = Nd4j.create(labelIndex >= 0 ? currList.size() - 1 : currList.size());
    int count = 0;
    for (int j = 0; j < currList.size(); j++) {
        if (labelIndex >= 0 && j == labelIndex) {
            if (numPossibleLabels < 1)
                throw new IllegalStateException("Number of possible labels invalid, must be >= 1");
            Writable current = currList.get(j);
            if (converter != null)
                current = converter.convert(current);
            label = FeatureUtil.toOutcomeVector(Double.valueOf(current.toString()).intValue(), numPossibleLabels);
        } else {
            Writable current = currList.get(j);
            featureVector.putScalar(count++, Double.valueOf(current.toString()));
        }
    }
    dataSets.add(new DataSet(featureVector, labelIndex >= 0 ? label : featureVector));
    List<INDArray> inputs = new ArrayList<>();
    List<INDArray> labels = new ArrayList<>();
    for (DataSet data : dataSets) {
        inputs.add(data.getFeatureMatrix());
        labels.add(data.getLabels());
    }
    return new DataSet(Nd4j.vstack(inputs.toArray(new INDArray[0])), Nd4j.vstack(labels.toArray(new INDArray[0])));
}
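A sketch of applying this function to an RDD of raw CSV lines (the CSVRecordReader choice, the constructor argument order, and the pre-existing JavaSparkContext sc are assumptions):

//Hypothetical usage: map CSV strings straight to DataSets
JavaRDD<String> lines = sc.parallelize(Arrays.asList("0.2,0.9,1", "0.4,0.1,0"));
//label in column 2, with 2 possible classes -> 1x2 feature vector plus a one-hot 1x2 label
JavaRDD<DataSet> data = lines.map(new RecordReaderFunction(new CSVRecordReader(), 2, 2));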
Use of org.datavec.api.writable.Writable in project deeplearning4j by deeplearning4j:
the testDataVecSequenceDataSetFunction method of the class TestDataVecDataSetFunctions.
@Test
public void testDataVecSequenceDataSetFunction() throws Exception {
    JavaSparkContext sc = getContext();
    //Test Spark record reader functionality vs. local
    File f = new File("src/test/resources/csvsequence/csvsequence_0.txt");
    String path = f.getPath();
    //strip the 17-character file name ("csvsequence_0.txt") to get the containing folder
    String folder = path.substring(0, path.length() - 17);
    path = folder + "*";
    JavaPairRDD<String, PortableDataStream> origData = sc.binaryFiles(path);
    //3 CSV sequences
    assertEquals(3, origData.count());
    SequenceRecordReader seqRR = new CSVSequenceRecordReader(1, ",");
    SequenceRecordReaderFunction rrf = new SequenceRecordReaderFunction(seqRR);
    JavaRDD<List<List<Writable>>> rdd = origData.map(rrf);
    JavaRDD<DataSet> data = rdd.map(new DataVecSequenceDataSetFunction(2, -1, true, null, null));
    List<DataSet> collected = data.collect();
    //Load normally (i.e., not via Spark), and check that we get the same results (order notwithstanding)
    InputSplit is = new FileSplit(new File(folder), new String[] { "txt" }, true);
    SequenceRecordReader seqRR2 = new CSVSequenceRecordReader(1, ",");
    seqRR2.initialize(is);
    SequenceRecordReaderDataSetIterator iter = new SequenceRecordReaderDataSetIterator(seqRR2, 1, -1, 2, true);
    List<DataSet> listLocal = new ArrayList<>(3);
    while (iter.hasNext()) {
        listLocal.add(iter.next());
    }
    //Compare:
    assertEquals(3, collected.size());
    assertEquals(3, listLocal.size());
    //Check that the results are the same (order notwithstanding)
    boolean[] found = new boolean[3];
    for (int i = 0; i < 3; i++) {
        int foundIndex = -1;
        DataSet ds = collected.get(i);
        for (int j = 0; j < 3; j++) {
            if (ds.equals(listLocal.get(j))) {
                if (foundIndex != -1)
                    //already matched -> this Spark value equals two or more local values (shouldn't happen)
                    fail();
                foundIndex = j;
                if (found[foundIndex])
                    //another Spark value already matched this local one -> suggests duplicates in the Spark list
                    fail();
                //mark this local value as seen
                found[foundIndex] = true;
            }
        }
    }
    int count = 0;
    for (boolean b : found)
        if (b)
            count++;
    //Expect exactly 3 pairwise matches between the Spark and local versions
    assertEquals(3, count);
}
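The test relies on a getContext() helper from its test class; a plausible minimal version (an assumption — the real helper may configure Spark differently) is:

private static JavaSparkContext getContext() {
    //local-mode Spark context for unit testing; master URL and app name are illustrative
    SparkConf conf = new SparkConf().setMaster("local[*]").setAppName("TestDataVecDataSetFunctions");
    return new JavaSparkContext(conf);
}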