Usage of `org.nd4j.linalg.dataset.api.MultiDataSet` in the deeplearning4j project (by deeplearning4j).
Class `BatchAndExportMultiDataSetsFunction`, method `call`.
/**
 * Exports the MultiDataSets of one partition and returns the paths that were written.
 *
 * <p>A MultiDataSet whose first feature array already has exactly {@code minibatchSize}
 * examples is exported directly. Any other MultiDataSet is appended to a pending list, and
 * that list is handed to {@code processList} (defined elsewhere). {@code processList} returns
 * the updated export counter and any paths it produced. After the input is exhausted,
 * {@code processList} is called once more with its final flag set to {@code true}, so that
 * left-over buffered examples can be flushed.
 *
 * @param partitionIdx index of the partition being processed, forwarded to the export calls
 * @param iterator     the partition's MultiDataSets
 * @return iterator over all exported paths, in the order they were produced
 * @throws Exception propagated from {@code export} / {@code processList}
 */
@Override
public Iterator<String> call(Integer partitionIdx, Iterator<MultiDataSet> iterator) throws Exception {
    List<String> exportedPaths = new ArrayList<>();
    LinkedList<MultiDataSet> pending = new LinkedList<>();
    int fileIndex = 0;

    while (iterator.hasNext()) {
        MultiDataSet current = iterator.next();
        boolean alreadyMinibatchSized = current.getFeatures(0).size(0) == minibatchSize;
        if (alreadyMinibatchSized) {
            // Exact fit: no regrouping needed, export as-is under the current index.
            exportedPaths.add(export(current, partitionIdx, fileIndex));
            fileIndex++;
        } else {
            // Smaller or larger than one minibatch: buffer it and let processList regroup.
            pending.add(current);
            Pair<Integer, List<String>> progress = processList(pending, partitionIdx, fileIndex, false);
            List<String> newPaths = progress.getSecond();
            if (newPaths != null && !newPaths.isEmpty()) {
                exportedPaths.addAll(newPaths);
            }
            fileIndex = progress.getFirst();
        }
    }

    // Final pass: anything still buffered gets a chance to be written out.
    Pair<Integer, List<String>> remainder = processList(pending, partitionIdx, fileIndex, true);
    List<String> remainderPaths = remainder.getSecond();
    if (remainderPaths != null && !remainderPaths.isEmpty()) {
        exportedPaths.addAll(remainderPaths);
    }
    return exportedPaths.iterator();
}
Usage of `org.nd4j.linalg.dataset.api.MultiDataSet` in the deeplearning4j project (by deeplearning4j).
Class `PortableDataStreamMultiDataSetIterator`, method `next`.
/**
 * Deserializes the next MultiDataSet from the underlying PortableDataStream iterator.
 *
 * <p>A fresh {@code org.nd4j.linalg.dataset.MultiDataSet} is loaded from the stream.
 * The stream is always closed via try-with-resources. The configured preprocessor,
 * if any, is then applied in place.
 *
 * @return the loaded (and possibly preprocessed) MultiDataSet
 * @throws RuntimeException wrapping any IOException raised while opening, reading or closing the stream
 */
@Override
public MultiDataSet next() {
    MultiDataSet loaded = new org.nd4j.linalg.dataset.MultiDataSet();
    PortableDataStream source = iter.next();
    try (InputStream in = source.open()) {
        loaded.load(in);
    } catch (IOException ioe) {
        throw new RuntimeException(ioe);
    }
    if (preprocessor == null) {
        return loaded;
    }
    preprocessor.preProcess(loaded);
    return loaded;
}
Usage of `org.nd4j.linalg.dataset.api.MultiDataSet` in the deeplearning4j project (by deeplearning4j).
Class `ScoreExamplesWithKeyFunctionAdapter`, method `call`.
/**
 * Scores every keyed, single-example MultiDataSet in this partition.
 *
 * <p>A ComputationGraph is rebuilt from {@code jsonConfig} and initialized with a copy of the
 * parameters held in {@code params}. Examples are then merged into groups of up to
 * {@code batchSize} and scored with {@code scoreExamples(data, addRegularization)}. Each score
 * is paired with the key of the example it belongs to.
 *
 * @param iterator keyed MultiDataSets; each must contain exactly one example
 * @return one {@code (key, score)} tuple per input example, in input order; empty if the
 *         partition is empty
 * @throws IllegalStateException if the parameter count does not match the network, if any
 *         MultiDataSet does not contain exactly one example, or if the network returns a
 *         different number of scores than examples in a batch
 */
@Override
public Iterable<Tuple2<K, Double>> call(Iterator<Tuple2<K, MultiDataSet>> iterator) throws Exception {
    if (!iterator.hasNext()) {
        return Collections.emptyList();
    }
    ComputationGraph network = new ComputationGraph(ComputationGraphConfiguration.fromJson(jsonConfig.getValue()));
    network.init();
    // Work on a private copy so the shared parameter array is never modified through this network.
    INDArray val = params.value().unsafeDuplication();
    if (val.length() != network.numParams(false))
        throw new IllegalStateException("Network did not have same number of parameters as the broadcast set parameters");
    network.setParams(val);

    List<Tuple2<K, Double>> ret = new ArrayList<>();
    List<MultiDataSet> collect = new ArrayList<>(batchSize);
    List<K> collectKey = new ArrayList<>(batchSize);
    int totalCount = 0;
    while (iterator.hasNext()) {
        collect.clear();
        collectKey.clear();
        int nExamples = 0;
        while (iterator.hasNext() && nExamples < batchSize) {
            Tuple2<K, MultiDataSet> t2 = iterator.next();
            MultiDataSet ds = t2._2();
            int n = ds.getFeatures(0).size(0);
            // One key per data set only makes sense if each data set holds exactly one example
            // (this also rejects empty data sets, not just ones with more than one example).
            if (n != 1)
                throw new IllegalStateException("Cannot score examples with one key per data set if "
                        + "data set does not contain exactly 1 example (numExamples: " + n + ")");
            collect.add(ds);
            collectKey.add(t2._1());
            nExamples += n;
        }
        totalCount += nExamples;
        MultiDataSet data = org.nd4j.linalg.dataset.MultiDataSet.merge(collect);
        INDArray scores = network.scoreExamples(data, addRegularization);
        // Read scores through the INDArray API instead of scores.data().asDouble().
        // data() exposes the whole backing buffer, which for a view or offset array
        // can contain elements that are not part of 'scores'.
        int numScores = (int) scores.length();
        if (numScores != collectKey.size())
            throw new IllegalStateException("Number of scores (" + numScores + ") does not match number of examples ("
                    + collectKey.size() + ")");
        for (int i = 0; i < numScores; i++) {
            ret.add(new Tuple2<>(collectKey.get(i), scores.getDouble(i)));
        }
    }
    if (Nd4j.getExecutioner() instanceof GridExecutioner)
        ((GridExecutioner) Nd4j.getExecutioner()).flushQueueBlocking();
    if (log.isDebugEnabled()) {
        log.debug("Scored {} examples ", totalCount);
    }
    return ret;
}
Aggregations